def __init__(self, opt):
        """Create the networks, optimizers and LR schedulers described by ``opt``.

        Args:
            opt: Nested option mapping. This constructor reads
                ``opt['network_G']``, ``opt['network_D']``,
                ``opt['path']['pretrain_G']`` and, under ``opt['train']``,
                ``lr_G``/``b1_G``/``b2_G``, ``lr_D``/``b1_D``/``b2_D``,
                ``lr_steps`` and ``lr_gamma``.
        """
        self.device = torch.device('cuda')
        self.opt = opt

        # Generator: initialised first, then optionally overwritten by a
        # pretrained state dict.
        self.G = Generator(self.opt['network_G']).to(self.device)
        util.init_weights(self.G, init_type='kaiming', scale=0.1)
        pretrain_path = self.opt['path']['pretrain_G']
        if pretrain_path:
            # map_location makes loading independent of the device the
            # checkpoint tensors were saved from.
            state_dict = torch.load(pretrain_path, map_location=self.device)
            self.G.load_state_dict(state_dict, strict=True)

        self.D = Discriminator(self.opt['network_D']).to(self.device)
        util.init_weights(self.D, init_type='kaiming', scale=1)

        self.FE = VGGFeatureExtractor().to(self.device)

        # G and D are put in training mode; the feature extractor stays in
        # eval mode.
        self.G.train()
        self.D.train()
        self.FE.eval()

        self.log_dict = OrderedDict()

        train_opt = self.opt['train']

        # Only generator parameters that require gradients are optimised.
        self.optim_params = [
            p for p in self.G.parameters() if p.requires_grad
        ]
        self.opt_G = torch.optim.Adam(self.optim_params,
                                      lr=train_opt['lr_G'],
                                      betas=(train_opt['b1_G'],
                                             train_opt['b2_G']))
        self.opt_D = torch.optim.Adam(self.D.parameters(),
                                      lr=train_opt['lr_D'],
                                      betas=(train_opt['b1_D'],
                                             train_opt['b2_D']))

        # One MultiStepLR scheduler per optimizer, sharing milestones/gamma.
        self.optimizers = [self.opt_G, self.opt_D]
        self.schedulers = [
            lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'],
                                     train_opt['lr_gamma'])
            for optimizer in self.optimizers
        ]
Beispiel #2
0
    def __init__(self):
        """Set up data loading, both networks, their optimizers and a classifier.

        The checkpoint returned by ``self.load_checkpoint(model_dump_path)``
        decides the start state: with no checkpoint the networks get fresh
        weights and training starts at epoch 0; otherwise network weights,
        optimizer state and the epoch counter are restored from it.

        All hyper-parameters (``batch_size``, ``num_workers``, ``tag_size``,
        ``learning_rate``, ``beta_1``, ``device``, ``model_dump_path``) are
        read from module-level names.
        """
        logger.info('Set Data Loader')
        self.dataset = FoodDataset(transform=transforms.Compose([ToTensor()]))
        self.data_loader = torch.utils.data.DataLoader(self.dataset,
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       num_workers=num_workers,
                                                       drop_last=True)
        checkpoint, checkpoint_name = self.load_checkpoint(model_dump_path)
        if checkpoint is None:
            logger.info(
                'Don\'t have pre-trained model. Ignore loading model process.')
            logger.info('Set Generator and Discriminator')
            self.G = Generator(tag=tag_size).to(device)
            self.D = Discriminator(tag=tag_size).to(device)
            logger.info('Initialize Weights')
            # apply() mutates the modules in place; they are already on
            # ``device`` so no further move is needed.
            self.G.apply(initital_network_weights)
            self.D.apply(initital_network_weights)
            logger.info('Set Optimizers')
            self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                                lr=learning_rate,
                                                betas=(beta_1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                                lr=learning_rate,
                                                betas=(beta_1, 0.999))
            self.epoch = 0
        else:
            logger.info('Load Generator and Discriminator')
            self.G = Generator(tag=tag_size).to(device)
            self.D = Discriminator(tag=tag_size).to(device)
            # The original format string had no placeholder, so the
            # checkpoint name never reached the log.
            logger.info('Load Pre-Trained Weights From Checkpoint {}'.format(
                checkpoint_name))
            self.G.load_state_dict(checkpoint['G'])
            self.D.load_state_dict(checkpoint['D'])
            logger.info('Load Optimizers')
            self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                                lr=learning_rate,
                                                betas=(beta_1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                                lr=learning_rate,
                                                betas=(beta_1, 0.999))
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])

            self.epoch = checkpoint['epoch']
        logger.info('Set Criterion')
        # Auxiliary AlexNet with ``tag_size`` outputs and its own optimizer.
        # It is always created fresh; nothing is restored for it from the
        # checkpoint here.
        self.a_D = alexnet.alexnet(num_classes=tag_size).to(device)
        self.optimizer_a_D = torch.optim.Adam(self.a_D.parameters(),
                                              lr=learning_rate,
                                              betas=(beta_1, .999))
Beispiel #3
0
    def _build_model(self):
        """Instantiate both networks on the GPU and create their optimizers.

        Layer sizes and the dropout/activation settings come from
        ``self.config``; the optimizers are produced by
        ``self._setup_optimizers``. Both networks are registered with
        ``wandb.watch``.

        Returns:
            The tuple ``(generator, discriminator, g_optimizer, d_optimizer)``.
        """
        cuda = torch.device('cuda')

        # Read every setting up front, in the same order as before.
        data_dim = self.config.data['dimension']
        model_cfg = self.config.model
        gen_hidden = model_cfg['generator_hidden_layers']
        dropout_on = model_cfg['use_dropout']
        dropout_p = model_cfg['drop_prob']
        ac_func_on = model_cfg['use_ac_func']
        ac_func = model_cfg['activation']
        disc_hidden = model_cfg['disc_hidden_layers']

        # Settings shared by the two network constructors.
        shared_args = (dropout_on, dropout_p, ac_func_on, ac_func)
        banner = "Loading {} network ..."

        # Generator: latent_dim -> hidden layers -> data dimension.
        logger.log(banner.format(colored('generator', 'red')))
        gen_sizes = [self.latent_dim] + list(gen_hidden) + [data_dim]
        generator = Generator(gen_sizes, *shared_args).to(cuda)

        # Discriminator: data dimension -> hidden layers -> single output.
        logger.log(banner.format(colored('discriminator', 'red')))
        disc_sizes = [data_dim] + list(disc_hidden) + [1]
        discriminator = Discriminator(disc_sizes, *shared_args).to(cuda)

        wandb.watch([generator, discriminator])

        optimizers = self._setup_optimizers(generator, discriminator)
        g_optimizer, d_optimizer = optimizers

        return generator, discriminator, g_optimizer, d_optimizer
Beispiel #4
0
    def __init__(self, args):
        """Build the segmentation generator, discriminator, losses and optimizers.

        Args:
            args: Run configuration. Reads ``data_modality`` (``'oct'`` or
                ``'fundus'``), ``lr``, ``d2g_lr``, ``weight_decay``, ``b1``,
                ``b2``, ``resume``, ``output_root`` and ``project``. When
                resuming, ``args.start_epoch`` is overwritten with the epoch
                stored in the checkpoint.

        Raises:
            AssertionError: If ``args.data_modality`` is not a known modality.
            ValueError: If ``args.resume`` is set but the file does not exist.
        """
        super(SegTransferModel, self).__init__()
        self.args = args

        modality = args.data_modality
        assert modality in ('oct', 'fundus'), \
            'error in seg_mode, got {}'.format(modality)

        # Fundus segmentation has 1 output class, OCT has 12.
        seg_classes = 1 if self.args.data_modality == 'fundus' else 12
        generator = UNet_4mp(n_channels=1, n_classes=seg_classes)
        critic = Discriminator(in_channels=1)

        # Wrap for multi-GPU use and move to the GPU.
        generator = nn.DataParallel(generator).cuda()
        critic = nn.DataParallel(critic).cuda()

        # Register networks and loss modules under their attribute names.
        submodules = (
            ('model_G', generator),
            ('model_D', critic),
            ('l1_loss', nn.L1Loss().cuda()),
            ('nll_loss', nn.NLLLoss().cuda()),
            ('adversarial_loss', AdversarialLoss().cuda()),
        )
        for attr_name, module in submodules:
            self.add_module(attr_name, module)

        # The discriminator learning rate is the generator's scaled by d2g_lr.
        self.optimizer_G = torch.optim.Adam(
            params=self.model_G.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
            betas=(args.b1, args.b2))
        self.optimizer_D = torch.optim.Adam(
            params=self.model_D.parameters(),
            lr=args.lr * args.d2g_lr,
            weight_decay=args.weight_decay,
            betas=(args.b1, args.b2))

        # Optional resume: only the generator weights and the start epoch
        # are restored from the checkpoint.
        if not self.args.resume:
            return

        ckpt_path = os.path.join(args.output_root, args.project,
                                 'checkpoints', args.resume)
        if not os.path.isfile(ckpt_path):
            raise ValueError("=> no checkpoint found at '{}'".format(
                args.resume))

        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(ckpt_path)
        args.start_epoch = checkpoint['epoch']
        self.model_G.load_state_dict(checkpoint['state_dict_G'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
    def __init__(self, args, ablation_mode=4):
        """Build the two-stage generator, discriminator, losses and optimizers.

        ``model_G1`` is always loaded from a first-stage segmentation
        checkpoint; the generator optimizer covers only ``model_G2``.

        Args:
            args: Run configuration. Reads ``data_modality``, ``lr``,
                ``d2g_lr``, ``weight_decay``, ``b1``, ``b2``, ``server``,
                ``project``, ``DA_ablation_mode_isee`` (fundus only),
                ``resume`` and ``output_root``. When resuming,
                ``args.start_epoch`` is overwritten from the checkpoint.
            ablation_mode: Ablation-study switch forwarded to the
                reconstruction network:
                0 = output_structure (1 feature),
                2 = image (1 feature), i.e. auto-encoder,
                4 = output_structure + image (2 features).

        Raises:
            NotImplementedError: For fundus data when
                ``args.DA_ablation_mode_isee`` is not 0, 0.001 or 0.0001.
            ValueError: If the first-stage checkpoint file does not exist.
        """
        super(PNetModel, self).__init__()
        self.args = args

        is_fundus = self.args.data_modality == 'fundus'

        # Networks are created in the order G1, G2, D, then wrapped for
        # multi-GPU use and moved to the GPU.
        structure_net = Strcutre_Extraction_Network(
            n_channels=1, n_classes=1 if is_fundus else 12)
        recon_net = Image_Reconstruction_Network(
            in_ch=1,
            modality=self.args.data_modality,
            ablation_mode=ablation_mode)
        critic = Discriminator(in_channels=1)

        structure_net = nn.DataParallel(structure_net).cuda()
        recon_net = nn.DataParallel(recon_net).cuda()
        critic = nn.DataParallel(critic).cuda()

        submodules = (
            ('model_G1', structure_net),
            ('model_G2', recon_net),
            ('model_D', critic),
            ('l1_loss', nn.L1Loss().cuda()),
            ('adversarial_loss', AdversarialLoss().cuda()),
        )
        for attr_name, module in submodules:
            self.add_module(attr_name, module)

        # Only G2 is optimised on the generator side; the discriminator
        # learning rate is the generator's scaled by d2g_lr.
        self.optimizer_G = torch.optim.Adam(
            params=self.model_G2.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
            betas=(args.b1, args.b2))
        self.optimizer_D = torch.optim.Adam(
            params=self.model_D.parameters(),
            lr=args.lr * args.d2g_lr,
            weight_decay=args.weight_decay,
            betas=(args.b1, args.b2))

        # Locate the first-stage (segmentation) checkpoint.
        if self.args.server == 'ai':
            workspace = '/root/workspace'
        else:
            workspace = '/home/imed/new_disk/workspace'
        seg_ckpt_root = os.path.join(workspace, args.project, 'save_models')

        if is_fundus:
            # Map the ablation value to the file-name suffix; 0.0001 is the
            # default setting.
            isee = self.args.DA_ablation_mode_isee
            for known_value, suffix in ((0, '0'), (0.001, '001'),
                                        (0.0001, '0001')):
                if isee == known_value:
                    _g_zero_point = suffix
                    break
            else:
                raise NotImplementedError('error')

            # The original segmentation model used '1st_fundus_seg_vgg.pth.tar'.
            ckpt_name = '1st_fundus_seg_g_{}.pth.tar'.format(_g_zero_point)
        else:
            ckpt_name = '1st_oct_seg.pth.tar'
        seg_ckpt_path = os.path.join(seg_ckpt_root, ckpt_name)

        # The first-stage weights are mandatory.
        if not os.path.isfile(seg_ckpt_path):
            raise ValueError(
                "=> no checkpoint found at '{}'".format(seg_ckpt_path))
        print("=> loading G1 checkpoint")
        checkpoint = torch.load(seg_ckpt_path)
        self.model_G1.load_state_dict(checkpoint['state_dict_G'])
        print("=> loaded G1 checkpoint (epoch {}) \n from {}".format(
            checkpoint['epoch'], seg_ckpt_path))

        # Optional resume of G2; a missing file is only reported, not fatal.
        if not self.args.resume:
            return

        ckpt_path = os.path.join(self.args.output_root, args.project,
                                 'checkpoints', args.resume)
        if not os.path.isfile(ckpt_path):
            print("=> no checkpoint found at '{}'".format(args.resume))
            return

        print("=> loading G2 checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(ckpt_path)
        args.start_epoch = checkpoint['epoch']
        self.model_G2.load_state_dict(checkpoint['state_dict_G'])
        print("=> loaded G2 checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))
def main(args):
    """Train a CartPole policy with PPO, using discriminator outputs as rewards.

    Each iteration rolls out one episode with the current policy, trains the
    discriminator to separate expert (state, action) pairs from the agent's,
    then updates the policy with PPO on the discriminator's rewards. The
    environment reward is only logged and used for the stopping criterion:
    training stops, and the model is saved, after 100 consecutive episodes
    with a return of at least 195.

    Args:
        args: Namespace providing ``gamma``, ``logdir``, ``savedir`` and
            ``iteration`` (maximum number of episodes).
    """
    env = gym.make('CartPole-v0')
    env.seed(0)
    ob_space = env.observation_space
    Policy = Policy_net('policy', env)
    Old_Policy = Policy_net('old_policy', env)
    PPO = PPOTrain(Policy, Old_Policy, gamma=args.gamma)
    D = Discriminator(env)

    expert_observations = np.genfromtxt('trajectory/observations.csv')
    expert_actions = np.genfromtxt('trajectory/actions.csv', dtype=np.int32)

    saver = tf.train.Saver()

    def scalar_summary(tag, value):
        """Wrap a single scalar in a tf.Summary for the file writer."""
        return tf.Summary(
            value=[tf.Summary.Value(tag=tag, simple_value=value)])

    with tf.Session() as sess:
        writer = tf.summary.FileWriter(args.logdir, sess.graph)
        # Close the writer even if training raises, so buffered events are
        # flushed to disk.
        try:
            sess.run(tf.global_variables_initializer())

            obs = env.reset()
            success_num = 0

            for iteration in range(args.iteration):
                observations = []
                actions = []
                # Environment rewards are NOT used to update the policy.
                rewards = []
                v_preds = []
                run_policy_steps = 0

                # Roll out one full episode.
                while True:
                    run_policy_steps += 1
                    # Add a batch axis to feed the Policy.obs placeholder.
                    obs = np.stack([obs]).astype(dtype=np.float32)
                    act, v_pred = Policy.act(obs=obs, stochastic=True)

                    # ndarray.item() replaces the removed np.asscalar().
                    act = act.item()
                    v_pred = v_pred.item()
                    next_obs, reward, done, info = env.step(act)

                    observations.append(obs)
                    actions.append(act)
                    rewards.append(reward)
                    v_preds.append(v_pred)

                    if done:
                        # Value estimate for the final state, used as the
                        # last entry of v_preds_next.
                        next_obs = np.stack([next_obs]).astype(
                            dtype=np.float32)
                        _, v_pred = Policy.act(obs=next_obs, stochastic=True)
                        v_preds_next = v_preds[1:] + [v_pred.item()]
                        obs = env.reset()
                        break
                    else:
                        obs = next_obs

                episode_reward = sum(rewards)
                writer.add_summary(
                    scalar_summary('episode_length', run_policy_steps),
                    iteration)
                writer.add_summary(
                    scalar_summary('episode_reward', episode_reward),
                    iteration)

                if episode_reward >= 195:
                    success_num += 1
                    if success_num >= 100:
                        saver.save(sess, args.savedir + '/model.ckpt')
                        print('Clear!! Model saved.')
                        break
                else:
                    success_num = 0

                # Convert lists to numpy arrays for feeding tf.placeholder.
                observations = np.reshape(
                    observations, newshape=[-1] + list(ob_space.shape))
                actions = np.array(actions).astype(dtype=np.int32)

                # Train the discriminator (two passes per episode).
                for _ in range(2):
                    D.train(expert_s=expert_observations,
                            expert_a=expert_actions,
                            agent_s=observations,
                            agent_a=actions)

                # The discriminator's output is the reward signal.
                d_rewards = D.get_rewards(agent_s=observations,
                                          agent_a=actions)
                d_rewards = np.reshape(d_rewards, newshape=[-1]).astype(
                    dtype=np.float32)

                gaes = PPO.get_gaes(rewards=d_rewards,
                                    v_preds=v_preds,
                                    v_preds_next=v_preds_next)
                gaes = np.array(gaes).astype(dtype=np.float32)
                # gaes = (gaes - gaes.mean()) / gaes.std()
                v_preds_next = np.array(v_preds_next).astype(dtype=np.float32)

                # Train the policy on random mini-batches of 32 samples.
                inp = [observations, actions, gaes, d_rewards, v_preds_next]
                PPO.assign_policy_parameters()
                for epoch in range(6):
                    # Indices are drawn with replacement from [low, high).
                    sample_indices = np.random.randint(
                        low=0, high=observations.shape[0], size=32)
                    sampled_inp = [
                        np.take(a=a, indices=sample_indices, axis=0)
                        for a in inp
                    ]
                    PPO.train(obs=sampled_inp[0],
                              actions=sampled_inp[1],
                              gaes=sampled_inp[2],
                              rewards=sampled_inp[3],
                              v_preds_next=sampled_inp[4])

                summary = PPO.get_summary(obs=inp[0],
                                          actions=inp[1],
                                          gaes=inp[2],
                                          rewards=inp[3],
                                          v_preds_next=inp[4])

                writer.add_summary(summary, iteration)
        finally:
            writer.close()
Beispiel #7
0
class SRGAN():
    """Trainer that owns the dataset, both networks, optimizers and the loop.

    The discriminator returns a real/fake logit and a tag prediction; both the
    generator and the discriminator are trained on the sum of the two losses,
    and the discriminator additionally gets a gradient penalty.

    Configuration (``batch_size``, ``num_workers``, ``learning_rate``,
    ``beta_1``, ``device``, ``model_dump_path``, ``max_epoch``, ``verbose``,
    ``verbose_T``, ``lambda_adv``, ``lambda_gp``, ``tmp_path``, ...) is read
    from module-level names.
    """

    def __init__(self):
        """Create the data loader, networks, optimizers and loss criteria.

        If ``self.load_checkpoint(model_dump_path)`` finds a checkpoint, the
        network weights, optimizer state and epoch counter are restored from
        it; otherwise weights are freshly initialised and the epoch is 0.
        """
        logger.info('Set Data Loader')
        self.dataset = AnimeFaceDataset(
            avatar_tag_dat_path=avatar_tag_dat_path,
            transform=transforms.Compose([ToTensor()]))
        self.data_loader = torch.utils.data.DataLoader(self.dataset,
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       num_workers=num_workers,
                                                       drop_last=True)
        checkpoint, checkpoint_name = self.load_checkpoint(model_dump_path)
        if checkpoint is None:
            logger.info(
                'Don\'t have pre-trained model. Ignore loading model process.')
            logger.info('Set Generator and Discriminator')
            self.G = Generator().to(device)
            self.D = Discriminator().to(device)
            logger.info('Initialize Weights')
            # apply() mutates the modules in place; they are already on
            # ``device`` so no further move is needed.
            self.G.apply(initital_network_weights)
            self.D.apply(initital_network_weights)
            logger.info('Set Optimizers')
            self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                                lr=learning_rate,
                                                betas=(beta_1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                                lr=learning_rate,
                                                betas=(beta_1, 0.999))
            self.epoch = 0
        else:
            logger.info('Load Generator and Discriminator')
            self.G = Generator().to(device)
            self.D = Discriminator().to(device)
            # The original format string had no placeholder, so the
            # checkpoint name never reached the log.
            logger.info('Load Pre-Trained Weights From Checkpoint {}'.format(
                checkpoint_name))
            self.G.load_state_dict(checkpoint['G'])
            self.D.load_state_dict(checkpoint['D'])
            logger.info('Load Optimizers')
            self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                                lr=learning_rate,
                                                betas=(beta_1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                                lr=learning_rate,
                                                betas=(beta_1, 0.999))
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])
            self.epoch = checkpoint['epoch']
        logger.info('Set Criterion')
        self.label_criterion = nn.BCEWithLogitsLoss().to(device)
        self.tag_criterion = nn.MultiLabelSoftMarginLoss().to(device)

    def load_checkpoint(self, model_dir):
        """Load the last checkpoint (by sorted file name) found in ``model_dir``.

        Args:
            model_dir: Directory passed to ``read_newest_model`` and used to
                build the checkpoint path.

        Returns:
            ``(checkpoint, path)`` where ``checkpoint`` is the object returned
            by ``torch.load``, or ``(None, None)`` when no checkpoint exists.
        """
        models_path = read_newest_model(model_dir)
        if not models_path:
            return None, None
        # The largest name is what sort()[-1] selected before; max() avoids
        # mutating the list returned by the helper.
        newest_name = max(models_path)
        # Join against the directory that was actually searched rather than
        # the global default.
        new_model_path = os.path.join(model_dir, newest_name)
        checkpoint = torch.load(new_model_path)
        return checkpoint, new_model_path

    def train(self):
        """Run the training loop from ``self.epoch`` through ``max_epoch``.

        At the start of every epoch a checkpoint is written and the learning
        rates are adjusted. For every batch the discriminator is updated on a
        real batch, a generated batch and a gradient penalty, then the
        generator is updated. Every ``verbose_T`` iterations a real and a
        generated image grid are written to ``tmp_path``.
        """
        iteration = -1
        # Target buffer for the real/fake criterion; it is overwritten with
        # 1.0 or 0.0 before every use. Sizes must be integers.
        label = torch.zeros(batch_size, 1, dtype=torch.float, device=device)
        logger.info('Current epoch: {}. Max epoch: {}.'.format(
            self.epoch, max_epoch))
        while self.epoch <= max_epoch:
            # Dump a checkpoint at the start of every epoch.
            checkpoint_path = '{}/checkpoint_{}.tar'.format(
                model_dump_path, str(self.epoch).zfill(4))
            torch.save(
                {
                    'epoch': self.epoch,
                    'D': self.D.state_dict(),
                    'G': self.G.state_dict(),
                    'optimizer_D': self.optimizer_D.state_dict(),
                    'optimizer_G': self.optimizer_G.state_dict(),
                }, checkpoint_path)
            logger.info('Checkpoint saved in: {}'.format(checkpoint_path))

            msg = {}
            adjust_learning_rate(self.optimizer_G, iteration)
            adjust_learning_rate(self.optimizer_D, iteration)
            for i, (avatar_tag, avatar_img) in enumerate(self.data_loader):
                iteration += 1
                on_report_step = iteration % verbose_T == 0
                log_now = verbose and on_report_step
                if log_now:
                    msg['epoch'] = int(self.epoch)
                    msg['step'] = int(i)
                    msg['iteration'] = iteration
                avatar_img = Variable(avatar_img).to(device)
                avatar_tag = Variable(torch.FloatTensor(avatar_tag)).to(device)

                # NOTE(review): the original comment here said
                # "D : G = 2 : 1", but the code performs one optimizer step
                # for each network per batch.
                # 1. Training D
                # 1.1. discriminate real images
                self.D.zero_grad()
                label_p, tag_p = self.D(avatar_img)
                label.data.fill_(1.0)

                # 1.2. loss on real images
                real_label_loss = self.label_criterion(label_p, label)
                real_tag_loss = self.tag_criterion(tag_p, avatar_tag)
                real_loss_sum = real_label_loss * lambda_adv / 2.0 + real_tag_loss * lambda_adv / 2.0
                real_loss_sum.backward()
                if log_now:
                    msg['discriminator real loss'] = float(real_loss_sum)

                # 1.3. discriminate generated images (detached, so no
                # gradient reaches G here)
                g_noise, fake_tag = fake_generator()
                fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                fake_img = self.G(fake_feat).detach()
                fake_label_p, fake_tag_p = self.D(fake_img)
                label.data.fill_(.0)

                # 1.4. loss on generated images
                fake_label_loss = self.label_criterion(fake_label_p, label)
                fake_tag_loss = self.tag_criterion(fake_tag_p, fake_tag)
                fake_loss_sum = fake_label_loss * lambda_adv / 2.0 + fake_tag_loss * lambda_adv / 2.0
                fake_loss_sum.backward()
                if log_now:
                    msg['discriminator fake loss'] = float(fake_loss_sum)

                # 1.5. gradient penalty
                # https://github.com/jfsantos/dragan-pytorch/blob/master/dragan.py
                alpha_size = [1] * avatar_img.dim()
                alpha_size[0] = avatar_img.size(0)
                alpha = torch.rand(alpha_size).to(device)
                real_data = avatar_img.data
                noise = Variable(torch.rand(avatar_img.size())).to(device)
                # Real batch perturbed by uniform noise scaled to half its
                # standard deviation.
                perturbed = real_data + 0.5 * real_data.std() * noise
                x_hat = Variable(alpha * real_data + (1 - alpha) * perturbed,
                                 requires_grad=True).to(device)
                pred_hat, _ = self.D(x_hat)
                gradients = grad(outputs=pred_hat,
                                 inputs=x_hat,
                                 grad_outputs=torch.ones(
                                     pred_hat.size()).to(device),
                                 create_graph=True,
                                 retain_graph=True,
                                 only_inputs=True)[0].view(x_hat.size(0), -1)
                # Penalise deviation of the per-sample gradient norm from 1.
                gradient_penalty = lambda_gp * (
                    (gradients.norm(2, dim=1) - 1)**2).mean()
                gradient_penalty.backward()
                if log_now:
                    msg['discriminator gradient penalty'] = float(
                        gradient_penalty)

                # 1.6. update D
                self.optimizer_D.step()

                # 2. Training G
                # 2.1. generate fake images
                self.G.zero_grad()
                g_noise, fake_tag = fake_generator()
                fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                fake_img = self.G(fake_feat)
                fake_label_p, fake_tag_p = self.D(fake_img)
                label.data.fill_(1.0)

                # 2.2. generator loss: D should label the fakes as real and
                # predict the tags they were generated from
                label_loss_g = self.label_criterion(fake_label_p, label)
                tag_loss_g = self.tag_criterion(fake_tag_p, fake_tag)
                loss_g = label_loss_g * lambda_adv / 2.0 + tag_loss_g * lambda_adv / 2.0
                loss_g.backward()
                if log_now:
                    msg['generator loss'] = float(loss_g)

                # 2.3. update G
                self.optimizer_G.step()

                if log_now:
                    logger.info(
                        '------------------------------------------')
                    for key in msg.keys():
                        logger.info('{} : {}'.format(key, msg[key]))

                # Save intermediate images (independent of ``verbose``).
                if on_report_step:
                    step_tag = str(iteration).zfill(8)
                    height, width = avatar_img.size(2), avatar_img.size(3)
                    vutils.save_image(
                        avatar_img.data.view(batch_size, 3, height, width),
                        os.path.join(tmp_path,
                                     'real_image_{}.png'.format(step_tag)))
                    g_noise, fake_tag = fake_generator()
                    fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                    # Preview only: no autograd graph is needed.
                    with torch.no_grad():
                        fake_img = self.G(fake_feat)
                    fake_image_path = os.path.join(
                        tmp_path, 'fake_image_{}.png'.format(step_tag))
                    vutils.save_image(
                        fake_img.data.view(batch_size, 3, height, width),
                        fake_image_path)
                    logger.info('Saved intermediate file in {}'.format(
                        fake_image_path))
            self.epoch += 1
 def get_discriminator(self, task_id):
     """Return a new Discriminator for ``task_id`` placed on ``self.args.device``."""
     target_device = self.args.device
     return Discriminator(self.args, task_id).to(target_device)
Beispiel #9
0
    RANDOM_ALL = True
    PRECROP = True if DATASET.lower() == 'ntu' else False
    VP_VALUE_COUNT = 1 if DATASET.lower() == 'ntu' else 3
    CLOSE_VIEWS = True if DATASET.lower() == 'panoptic' else False

    vgg_weights_path, i3d_weights_path, gen_weights_path, disc_weights_path = pretrained_weights_config()

    # generator
    generator = FullNetwork(vp_value_count=VP_VALUE_COUNT, stdev=STDEV,
                            output_shape=(BATCH_SIZE, CHANNELS, FRAMES, HEIGHT, WIDTH),
                            pretrained=True, vgg_weights_path=vgg_weights_path, i3d_weights_path=i3d_weights_path)
    if GEN_PRETRAINED:
        generator.load_state_dict(torch.load(gen_weights_path))
    generator = generator.to(device)
    # discriminator
    discriminator = Discriminator(in_channels=3, pretrained=GEN_PRETRAINED, weights_path=disc_weights_path)
    discriminator = discriminator.to(device)

    if device == 'cuda':
        net = torch.nn.DataParallel(generator)
        cudnn.benchmark = True

    # Loss functions
    criterion = nn.MSELoss()
    adversarial_loss = nn.BCELoss()
    perceptual_loss = vgg16().to(device)
    # categorical_loss = torch.nn.CrossEntropyLoss()
    # continuous_loss = torch.nn.MSELoss()

    optimizer_G = optim.Adam(generator.parameters(), lr=LR)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=LR)
Beispiel #10
0
else:
    device = torch.device("cpu")

random.seed(args.SEED)
torch.manual_seed(args.SEED)
if args.USE_CUDA:
    torch.cuda.manual_seed_all(args.SEED)

# Initialize Generator
generator = Generator(args.GAN_TYPE, args.ZDIM, args.NUM_CLASSES)
generator.apply(weights_init)
generator.to(device)
print(generator)

# Initialize Discriminator
discriminator = Discriminator(args.GAN_TYPE, args.NUM_CLASSES)
discriminator.apply(weights_init)
discriminator.to(device)
print(discriminator)

# Initialize loss function and optimizer
criterionLabel = nn.BCELoss()
criterionClass = nn.CrossEntropyLoss()
optimizerD = Adam(discriminator.parameters(), lr=args.LR, betas=(0.5, 0.999))
optimizerG = Adam(generator.parameters(), lr=args.LR, betas=(0.5, 0.999))

# Prepare the noise for evaluation during training phase
fixedNoise = torch.FloatTensor(args.BATCH_SIZE, args.ZDIM, 1, 1).normal_(0, 1)
if args.GAN_TYPE in ["CGAN", "ACGAN"]:
    fixedClass = F.one_hot(torch.LongTensor([i % args.NUM_CLASSES for i in range(args.BATCH_SIZE)]), num_classes=args.NUM_CLASSES)
    fixedConstraint = fixedClass.unsqueeze(-1).unsqueeze(-1)
Beispiel #11
0
                                               beta_1=0.5,
                                               beta_2=0.999)
optimizer_keypoint_detector = tf.keras.optimizers.Adam(learning_rate=lr,
                                                       beta_1=0.5,
                                                       beta_2=0.999)
optimizer_discriminator = tf.keras.optimizers.Adam(learning_rate=lr,
                                                   beta_1=0.5,
                                                   beta_2=0.999)

batch_size = 20
epochs = 150
train_steps = 99  # change

keypoint_detector = KeypointDetector()
generator = Generator()
discriminator = Discriminator()

generator_full = FullGenerator(keypoint_detector, generator, discriminator)
discriminator_full = FullDiscriminator(discriminator)


@tf.function
def train_step(source_images, driving_images):
    with tf.GradientTape(persistent=True) as tape:
        losses_generator, generated = generator_full(source_images,
                                                     driving_images, tape)
        generator_loss = tf.math.reduce_sum(list(losses_generator.values()))

    generator_gradients = tape.gradient(generator_loss,
                                        generator_full.trainable_variables)
    keypoint_detector_gradients = tape.gradient(
Beispiel #12
0
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    RANDOM_ALL = True
    PRECROP = True if DATASET.lower() == 'ntu' else False
    VP_VALUE_COUNT = 1 if DATASET.lower() == 'ntu' else 3
    CLOSE_VIEWS = True if DATASET.lower() == 'panoptic' else False

    # vgg_weights_path, i3d_weights_path = pretrained_weights_config()

    # generator
    generator = FullNetwork(vp_value_count=VP_VALUE_COUNT, stdev=STDEV,
                            output_shape=(BATCH_SIZE, CHANNELS, FRAMES, HEIGHT, WIDTH))
                            # pretrained=True, vgg_weights_path=vgg_weights_path, i3d_weights_path=i3d_weights_path)
    generator = generator.to(device)
    # discriminator
    discriminator = Discriminator(in_channels=3)
    discriminator = discriminator.to(device)

    if device == 'cuda':
        net = torch.nn.DataParallel(generator)
        cudnn.benchmark = True

    # Loss functions
    criterion = nn.MSELoss()
    adversarial_loss = nn.BCELoss()
    # categorical_loss = torch.nn.CrossEntropyLoss()
    # continuous_loss = torch.nn.MSELoss()

    optimizer_G = optim.Adam(generator.parameters(), lr=LR)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=LR)
Beispiel #13
0
    test_noise = [[]]
    img = G(torch.Tensor(test_noise).to(device))

    t_img = vutils.make_grid(img.data.view(1, 3, 128, 128)).numpy()
    t_img = np.transpose(t_img, (1, 2, 0))
    t_img[t_img < 0] = 0
    # min_max_scaler = preprocessing.MinMaxScaler()
    # t_img[..., 0] = min_max_scaler.fit_transform(t_img[..., 0])
    # t_img[..., 1] = min_max_scaler.fit_transform(t_img[..., 1])
    # t_img[..., 2] = min_max_scaler.fit_transform(t_img[..., 2])
    plt.imshow(t_img)
    plt.show()
    label_p, tag_p = D(img)
    label = Variable(torch.FloatTensor(1, 1.0)).to(device)
    lbl_criterion = nn.BCEWithLogitsLoss().to(device)
    loss = lbl_criterion(label_p, label) * 10000
    loss.backward()
    print(grad_list[0][0].shape)


if __name__ == '__main__':
    # Script entry point: rebuild both networks, restore their weights from
    # the dumped checkpoint and run the discriminator-backward experiment.
    checkpoint, _ = load_checkpoint(model_dump_path)

    G = Generator().to(device)
    G.load_state_dict(checkpoint['G'])
    D = Discriminator().to(device)
    D.load_state_dict(checkpoint['D'])

    # Alternative experiment, currently disabled:
    # generate(G, 'test', ['white hair'])
    image_backward_D(G, D)
Beispiel #14
0
class SRGAN():
    def __init__(self):
        """Set up data loading, networks and optimizers, resuming if possible.

        ``load_checkpoint(model_dump_path)`` decides the mode: when it yields
        a checkpoint, the generator/discriminator weights, both optimizer
        states and the epoch counter are restored from it; otherwise
        ``initital_network_weights`` is applied to both networks and training
        starts at epoch 0.  The auxiliary network ``a_D`` and its optimizer
        are always created fresh (no checkpoint entry is read for them).
        """
        logger.info('Set Data Loader')
        self.dataset = FoodDataset(transform=transforms.Compose([ToTensor()]))
        self.data_loader = torch.utils.data.DataLoader(self.dataset,
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       num_workers=num_workers,
                                                       drop_last=True)

        checkpoint, checkpoint_name = self.load_checkpoint(model_dump_path)
        # Identity test: `== None` would go through the object's __eq__.
        resuming = checkpoint is not None

        if resuming:
            logger.info('Load Generator and Discriminator')
        else:
            logger.info(
                'Don\'t have pre-trained model. Ignore loading model process.')
            logger.info('Set Generator and Discriminator')
        self.G = Generator(tag=tag_size).to(device)
        self.D = Discriminator(tag=tag_size).to(device)

        if resuming:
            # The original message had no placeholder, so the checkpoint name
            # passed to format() was silently dropped from the log.
            logger.info('Load Pre-Trained Weights From Checkpoint: {}'.format(
                checkpoint_name))
            self.G.load_state_dict(checkpoint['G'])
            self.D.load_state_dict(checkpoint['D'])
            logger.info('Load Optimizers')
        else:
            logger.info('Initialize Weights')
            self.G.apply(initital_network_weights).to(device)
            self.D.apply(initital_network_weights).to(device)
            logger.info('Set Optimizers')

        # The optimizers are built identically in both modes; only their
        # state differs (restored below when resuming).
        self.optimizer_G = torch.optim.Adam(self.G.parameters(),
                                            lr=learning_rate,
                                            betas=(beta_1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(),
                                            lr=learning_rate,
                                            betas=(beta_1, 0.999))

        if resuming:
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])
            self.epoch = checkpoint['epoch']
        else:
            self.epoch = 0

        logger.info('Set Criterion')
        self.a_D = alexnet.alexnet(num_classes=tag_size).to(device)
        self.optimizer_a_D = torch.optim.Adam(self.a_D.parameters(),
                                              lr=learning_rate,
                                              betas=(beta_1, .999))
        # self.label_criterion = nn.BCEWithLogitsLoss().to(device)
        # self.tag_criterion = nn.BCEWithLogitsLoss().to(device)

    def load_checkpoint(self, model_dir):
        """Load the newest checkpoint stored in ``model_dir``.

        The candidate file names come from ``utils.read_newest_model``; the
        lexicographically greatest name is treated as the newest.

        Args:
            model_dir: Directory that holds the checkpoint files.

        Returns:
            ``(checkpoint, path)`` with the object returned by ``torch.load``
            and the path it was read from, or ``(None, None)`` when no
            checkpoint file is reported.
        """
        models_path = utils.read_newest_model(model_dir)
        if not models_path:
            return None, None
        # Resolve the name against the directory that was actually searched
        # (the original joined it with the global `model_dump_path`, which is
        # only correct when the caller happens to pass that same directory).
        new_model_path = os.path.join(model_dir, max(models_path))
        # Without CUDA, tensors saved from a GPU must be remapped to the CPU;
        # with CUDA the default placement (map_location=None) is kept.
        map_location = None if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(new_model_path, map_location=map_location)
        return checkpoint, new_model_path

    def train(self):
        """Run the adversarial training loop until ``max_epoch`` is passed.

        Every batch performs one discriminator update (real batch, then a
        detached fake batch) followed by one generator update.  A checkpoint
        is written every 10000 iterations and at the end of every epoch, and
        real/fake image files are saved to ``tmp_path`` every ``verbose_T``
        iterations.  When ``verbose`` is set, the collected loss values are
        logged on those same iterations.
        """
        iteration = -1
        # Target tensor reused by every loss; it is filled with 1.0 or 0.0
        # right before each use, so its initial contents never matter.
        label = torch.empty(batch_size, 1, dtype=torch.float32, device=device)
        logger.info('Current epoch: {}. Max epoch: {}.'.format(
            self.epoch, max_epoch))
        while self.epoch <= max_epoch:
            msg = {}
            adjust_learning_rate(self.optimizer_G, iteration)
            adjust_learning_rate(self.optimizer_D, iteration)
            for i, (_, food_img) in enumerate(self.data_loader):
                iteration += 1
                if food_img.shape[0] != batch_size:
                    logger.warning('Batch size not satisfied. Ignoring.')
                    continue
                log_now = verbose and iteration % verbose_T == 0
                if log_now:
                    msg['epoch'] = int(self.epoch)
                    msg['step'] = int(i)
                    msg['iteration'] = iteration

                food_img = food_img.to(device)

                # 0. Auxiliary network forward pass.
                # NOTE(review): the output is discarded and optimizer_a_D is
                # never stepped in this method, so a_D is not trained here.
                # The call is kept so the run behaves as before -- confirm
                # whether it can be dropped.
                self.a_D.zero_grad()
                self.a_D(food_img)

                # 1. Training D
                # 1.1 / 1.2. real images should be classified as 1.
                self.D.zero_grad()
                label_p = self.D(food_img)
                label.fill_(1.0)
                real_loss_sum = F.binary_cross_entropy(label_p, label)
                real_loss_sum.backward()
                if log_now:
                    msg['discriminator real loss'] = float(real_loss_sum)

                # 1.3 / 1.4. generated images should be classified as 0; the
                # fake batch is detached so this pass leaves G untouched.
                g_noise, fake_tag = utils.fake_generator(
                    batch_size, noise_size, device)
                fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                fake_img = self.G(fake_feat).detach()
                fake_label_p = self.D(fake_img)
                label.fill_(.0)
                fake_loss_sum = F.binary_cross_entropy(fake_label_p, label)
                fake_loss_sum.backward()
                if log_now:
                    print('predicted fake label: {}'.format(fake_label_p))
                    msg['discriminator fake loss'] = float(fake_loss_sum)

                # 1.6. apply the accumulated real + fake gradients.
                self.optimizer_D.step()

                # 2. Training G
                # 2.1. generate a new fake batch; G wants D to output 1.
                self.G.zero_grad()
                g_noise, fake_tag = utils.fake_generator(
                    batch_size, noise_size, device)
                fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                fake_img = self.G(fake_feat)
                fake_label_p = self.D(fake_img)
                label.fill_(1.0)

                # NOTE(review): feat_loss is computed but not added to
                # loss_g, so it does not influence the generator update.
                a_D_feat = self.a_D(fake_img)
                feat_loss = F.binary_cross_entropy(a_D_feat, fake_tag)

                # 2.2. adversarial loss and generator update.
                loss_g = F.binary_cross_entropy(fake_label_p, label)
                loss_g.backward()
                if log_now:
                    msg['generator loss'] = float(loss_g)
                self.optimizer_G.step()

                if log_now:
                    logger.info('------------------------------------------')
                    for key, value in msg.items():
                        logger.info('{} : {}'.format(key, value))

                # save intermediate checkpoint
                if iteration % 10000 == 0:
                    self._save_checkpoint(str(iteration).zfill(8))

                # save a real batch and a freshly generated batch as images
                if iteration % verbose_T == 0:
                    stamp = str(iteration).zfill(8)
                    height, width = food_img.size(2), food_img.size(3)
                    vutils.save_image(
                        food_img.data.view(batch_size, 3, height, width),
                        os.path.join(tmp_path,
                                     'real_image_{}.png'.format(stamp)))
                    g_noise, fake_tag = utils.fake_generator(
                        batch_size, noise_size, device)
                    fake_feat = torch.cat([g_noise, fake_tag], dim=1)
                    # The preview is only written to disk, so no autograd
                    # graph needs to be recorded for it.
                    with torch.no_grad():
                        fake_img = self.G(fake_feat)
                    fake_path = os.path.join(
                        tmp_path, 'fake_image_{}.png'.format(stamp))
                    vutils.save_image(
                        fake_img.view(batch_size, 3, height, width),
                        fake_path)
                    logger.info(
                        'Saved intermediate file in {}'.format(fake_path))

            # dump the end-of-epoch checkpoint
            self._save_checkpoint(str(self.epoch).zfill(4))
            self.epoch += 1

    def _save_checkpoint(self, suffix):
        """Save networks, optimizers and epoch as ``checkpoint_<suffix>.tar``."""
        checkpoint_path = '{}/checkpoint_{}.tar'.format(model_dump_path,
                                                        suffix)
        torch.save(
            {
                'epoch': self.epoch,
                'D': self.D.state_dict(),
                'G': self.G.state_dict(),
                'optimizer_D': self.optimizer_D.state_dict(),
                'optimizer_G': self.optimizer_G.state_dict(),
            }, checkpoint_path)
        logger.info('Checkpoint saved in: {}'.format(checkpoint_path))
class Trainer():
    def __init__(self, opt):
        '''Build the networks, optimizers and LR schedulers described by opt.

        Reads ``opt['network_G']``, ``opt['network_D']``,
        ``opt['path']['pretrain_G']`` and the ``opt['train']`` entries for
        learning rates, Adam betas and the MultiStepLR steps/gamma.
        '''
        self.device = torch.device('cuda')
        self.opt = opt

        # Generator: random init first, then optionally overwritten by
        # pretrained weights.
        self.G = Generator(self.opt['network_G']).to(self.device)
        util.init_weights(self.G, init_type='kaiming', scale=0.1)
        pretrained_g = self.opt['path']['pretrain_G']
        if pretrained_g:
            self.G.load_state_dict(torch.load(pretrained_g), strict=True)

        # Discriminator and the VGG feature extractor.
        self.D = Discriminator(self.opt['network_D']).to(self.device)
        util.init_weights(self.D, init_type='kaiming', scale=1)
        self.FE = VGGFeatureExtractor().to(self.device)

        # G and D are trained; the feature extractor stays in eval mode.
        self.G.train()
        self.D.train()
        self.FE.eval()

        self.log_dict = OrderedDict()

        # Only generator parameters that require gradients are optimized.
        self.optim_params = [
            param for param in self.G.parameters() if param.requires_grad
        ]
        train_opt = self.opt['train']
        self.opt_G = torch.optim.Adam(
            self.optim_params,
            lr=train_opt['lr_G'],
            betas=(train_opt['b1_G'], train_opt['b2_G']))
        self.opt_D = torch.optim.Adam(
            self.D.parameters(),
            lr=train_opt['lr_D'],
            betas=(train_opt['b1_D'], train_opt['b2_D']))

        # One MultiStepLR scheduler per optimizer, in the same order.
        self.optimizers = [self.opt_G, self.opt_D]
        schedulers = []
        for optimizer in self.optimizers:
            schedulers.append(
                lr_scheduler.MultiStepLR(optimizer,
                                         milestones=train_opt['lr_steps'],
                                         gamma=train_opt['lr_gamma']))
        self.schedulers = schedulers

    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()

    def get_current_log(self):
        return self.log_dict

    def get_current_learning_rate(self):
        return self.schedulers[0].get_lr()[0]

    def load_model(self, step, strict=True):
        '''Load the generator and discriminator weights saved at ``step``.

        The files ``<step>_G.pth`` and ``<step>_D.pth`` are read from
        ``opt['path']['checkpoints']['models']``; ``strict`` is forwarded to
        ``load_state_dict``.
        '''
        models_dir = self.opt['path']['checkpoints']['models']
        for network, label in ((self.G, 'G'), (self.D, 'D')):
            weights = torch.load(f"{models_dir}/{step}_{label}.pth")
            network.load_state_dict(weights, strict=strict)

    def resume_training(self, resume_state):
        '''Resume the optimizers and schedulers for training'''

        resume_optimizers = resume_state['optimizers']
        resume_schedulers = resume_state['schedulers']
        assert len(resume_optimizers) == len(
            self.optimizers), 'Wrong lengths of optimizers'
        assert len(resume_schedulers) == len(
            self.schedulers), 'Wrong lengths of schedulers'
        for i, o in enumerate(resume_optimizers):
            self.optimizers[i].load_state_dict(o)
        for i, s in enumerate(resume_schedulers):
            self.schedulers[i].load_state_dict(s)

    def save_network(self, network, network_label, iter_step):
        '''Save a network's weights as ``<iter_step>_<network_label>.pth``.

        The file goes to ``opt['path']['checkpoints']['models']``.  A
        ``DataParallel`` wrapper is unwrapped first, and every entry of the
        state dict is moved to the CPU before saving.
        '''
        models_dir = self.opt['path']['checkpoints']['models']
        util.mkdir(models_dir)
        save_path = os.path.join(
            models_dir, '{}_{}.pth'.format(iter_step, network_label))

        module = network.module if isinstance(network,
                                              nn.DataParallel) else network
        state_dict = module.state_dict()
        # Replace values in place so the state-dict container itself is kept.
        for key in list(state_dict.keys()):
            state_dict[key] = state_dict[key].cpu()
        torch.save(state_dict, save_path)

    def save_model(self, epoch, current_step):
        self.save_network(self.G, 'G', current_step)
        self.save_network(self.D, 'D', current_step)
        self.save_training_state(epoch, current_step)

    def save_training_state(self, epoch, iter_step):
        '''Write ``<iter_step>.state`` with everything needed to resume.

        The saved dict holds the epoch, the iteration and the state dicts of
        all schedulers and optimizers (in list order); it is stored under
        ``opt['path']['checkpoints']['states']``.
        '''
        state = {
            'epoch': epoch,
            'iter': iter_step,
            'schedulers': [sched.state_dict() for sched in self.schedulers],
            'optimizers': [optim.state_dict() for optim in self.optimizers],
        }
        states_dir = self.opt['path']['checkpoints']['states']
        util.mkdir(states_dir)
        torch.save(state,
                   os.path.join(states_dir, '{}.state'.format(iter_step)))

    def train(self, train_batch, step):
        '''Run one generator update followed by one discriminator update.

        ``train_batch`` provides the ``'LR'`` and ``'HR'`` tensors.  The
        generator loss is the weighted sum of pixel, feature and relativistic
        GAN terms (weights ``wt_pix`` / ``wt_fea`` / ``wt_gan`` from
        ``opt['train']``).  ``step`` is accepted but not used in this method.
        Loss values are recorded in ``self.log_dict``.
        '''
        self.lr = train_batch['LR'].to(self.device)
        self.hr = train_batch['HR'].to(self.device)

        # ---- Generator update: D is frozen so only G receives gradients ----
        for param in self.D.parameters():
            param.requires_grad_(False)
        self.opt_G.zero_grad()

        self.sr = self.G(self.lr)
        weights = self.opt['train']

        # pixel loss
        l_g_pix = weights['wt_pix'] * cri_pix(self.sr, self.hr)

        # feature loss (the real features are treated as a constant target)
        real_fea = self.FE(self.hr).detach()
        fake_fea = self.FE(self.sr)
        l_g_fea = weights['wt_fea'] * cri_fea(fake_fea, real_fea)

        # ragan loss: each prediction is compared against the mean of the
        # opposite set, with the real/fake targets swapped for G
        pred_g_fake = self.D(self.sr)
        pred_d_real = self.D(self.hr).detach()
        gan_on_real = cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
        gan_on_fake = cri_gan(pred_g_fake - torch.mean(pred_d_real), True)
        l_g_gan = weights['wt_gan'] * (gan_on_real + gan_on_fake) / 2

        l_g_total = l_g_pix + l_g_fea + l_g_gan
        l_g_total.backward()
        self.opt_G.step()

        # ---- Discriminator update ----
        for param in self.D.parameters():
            param.requires_grad_(True)
        self.opt_D.zero_grad()

        pred_d_real = self.D(self.hr)
        pred_d_fake = self.D(self.sr.detach())  # detach to avoid BP to G
        l_d_real = cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
        l_d_fake = cri_gan(pred_d_fake - torch.mean(pred_d_real), False)

        l_d_total = (l_d_real + l_d_fake) / 2
        l_d_total.backward()
        self.opt_D.step()

        # ---- Logging ----
        # NOTE(review): D_real / D_fake are stored as tensors while the loss
        # entries are Python floats -- confirm the log consumers expect that.
        self.log_dict.update({
            'l_g_pix': l_g_pix.item(),
            'l_g_fea': l_g_fea.item(),
            'l_g_gan': l_g_gan.item(),
            'l_d_real': l_d_real.item(),
            'l_d_fake': l_d_fake.item(),
            'D_real': torch.mean(pred_d_real.detach()),
            'D_fake': torch.mean(pred_d_fake.detach()),
        })

    def validate(self, val_batch, current_step):
        avg_psnr = 0.0
        avg_ssim = 0.0
        idx = 0
        for _, val_data in enumerate(val_batch):
            idx += 1
            img_name = os.path.splitext(
                os.path.basename(val_data['LR_path'][0]))[0]
            img_dir = os.path.join(
                self.opt['path']['checkpoints']['val_image_dir'], img_name)
            util.mkdir(img_dir)

            self.val_lr = val_data['LR'].to(self.device)
            self.val_hr = val_data['HR'].to(self.device)

            self.G.eval()
            with torch.no_grad():
                self.val_sr = self.G(self.val_lr)
            self.G.train()

            val_LR = self.val_lr.detach()[0].float().cpu()
            val_SR = self.val_sr.detach()[0].float().cpu()
            val_HR = self.val_hr.detach()[0].float().cpu()

            sr_img = util.tensor2img(val_SR)  # uint8
            gt_img = util.tensor2img(val_HR)  # uint8

            # Save SR images for reference
            save_img_path = os.path.join(
                img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
            cv2.imwrite(save_img_path, sr_img)

            # calculate PSNR
            crop_size = 4
            gt_img = gt_img / 255.
            sr_img = sr_img / 255.
            cropped_sr_img = sr_img[crop_size:-crop_size,
                                    crop_size:-crop_size, :]
            cropped_gt_img = gt_img[crop_size:-crop_size,
                                    crop_size:-crop_size, :]
            avg_psnr += PSNR(cropped_sr_img * 255, cropped_gt_img * 255)
            avg_ssim += SSIM(cropped_sr_img * 255, cropped_gt_img * 255)

        avg_psnr = avg_psnr / idx
        avg_ssim = avg_ssim / idx
        return avg_psnr, avg_ssim
Beispiel #16
0
    train_dataset = TrainDataset(
        root=args.train_data_path,
        scale_factor=args.scale_factor,
        hr_size=args.hr_size,
        random_crop_size=args.random_crop_size,
    )
    test_dataset = TestDataset(root=args.test_data_path,
                               scale_factor=args.scale_factor,
                               hr_size=args.hr_size)
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=64)

    gen_net = Generator(num_res_blocks=args.gen_res_blocks,
                        upscale_factor=args.scale_factor).to(device)
    dis_net = Discriminator(hr_size=args.random_crop_size,
                            sigmoid=not args.no_sigmoid).to(device)
    print(f"Generator number of parameters: {count_parameters(gen_net)}")
    print(f"Discriminator number of parameters: {count_parameters(dis_net)}")

    gen_path = args.from_pretrained_gen
    if gen_path:
        gen_net.load_state_dict(torch.load(gen_path))
    dis_path = args.from_pretrained_dis
    if dis_path:
        dis_net.load_state_dict(torch.load(dis_path))

    perceptual_loss = PerceptualLoss(device=device)
    mse_loss = nn.MSELoss()
    beta1 = 0.9

    opt_gen = optim.Adam(gen_net.parameters(),
Beispiel #17
0
    run_name = 'correlation-GAN_{}'.format(config.version)
    wandb.init(name=run_name,
               dir=config.checkpoint_dir,
               notes=config.description)
    wandb.config.update(config.__dict__)

    device = torch.device('cuda')

    use_dropout = [True, True, False]
    drop_prob = [0.5, 0.5, 0.5]
    use_ac_func = [True, True, False]
    activation = 'relu'
    latent_dim = 10

    gen_fc_layers = [latent_dim, 16, 32, 2]
    generator = Generator(gen_fc_layers, use_dropout, drop_prob, use_ac_func,
                          activation).to(device)

    disc_fc_layers = [2, 32, 16, 1]
    discriminator = Discriminator(disc_fc_layers, use_dropout, drop_prob,
                                  use_ac_func, activation).to(device)

    wandb.watch([generator, discriminator])

    g_optimizer = Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.9))
    d_optimizer = Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.9))

    wgan_gp = WGAN_GP(config, generator, discriminator, g_optimizer,
                      d_optimizer, latent_shape)
    wgan_gp.train(dataloader, 200)