Example #1
class Hidden:
    def __init__(self, configuration: HiDDenConfiguration,
                 device: torch.device, noiser: Noiser, tb_logger):
        """
        :param configuration: Configuration for the net, such as the size of the input image, number of channels in the intermediate layers, etc.
        :param device: torch.device object, CPU or GPU
        :param noiser: Object representing stacked noise layers.
        :param tb_logger: Optional TensorboardX logger object, if specified -- enables Tensorboard logging
        """
        super(Hidden, self).__init__()

        self.encoder_decoder = EncoderDecoder(configuration, noiser).to(device)
        self.discriminator = Discriminator(configuration).to(device)
        self.optimizer_enc_dec = torch.optim.Adam(
            self.encoder_decoder.parameters())
        self.optimizer_discrim = torch.optim.Adam(
            self.discriminator.parameters())

        if configuration.use_vgg:
            self.vgg_loss = VGGLoss(3, 1, False)
            self.vgg_loss.to(device)
        else:
            self.vgg_loss = None

        self.config = configuration
        self.device = device

        self.bce_with_logits_loss = nn.BCEWithLogitsLoss().to(device)
        self.mse_loss = nn.MSELoss().to(device)

        # Define the labels used for training the discriminator/adversarial loss
        self.cover_label = 1
        self.encoded_label = 0

        self.tb_logger = tb_logger
        if tb_logger is not None:
            encoder_final = self.encoder_decoder.encoder._modules[
                'final_layer']
            encoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/encoder_out'))
            decoder_final = self.encoder_decoder.decoder._modules['linear']
            decoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/decoder_out'))
            discrim_final = self.discriminator._modules['linear']
            discrim_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/discrim_out'))

    def train_on_batch(self, batch: list):
        """
        Trains the network on a single batch consisting of images and messages
        :param batch: batch of training data, in the form [images, messages]
        :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch
        """
        images, messages = batch

        batch_size = images.shape[0]
        self.encoder_decoder.train()
        self.discriminator.train()
        with torch.enable_grad():
            # ---------------- Train the discriminator -----------------------------
            self.optimizer_discrim.zero_grad()
            # train on cover
            # BCEWithLogitsLoss expects float targets; newer PyTorch infers an
            # integer dtype from the int fill value, so pin the dtype explicitly
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              dtype=torch.float,
                                              device=self.device)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encoded_label,
                                                dtype=torch.float,
                                                device=self.device)
            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                dtype=torch.float,
                                                device=self.device)

            d_on_cover = self.discriminator(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)
            d_loss_on_cover.backward()

            # train on fake
            encoded_images, noised_images, decoded_messages = self.encoder_decoder(
                images, messages)
            d_on_encoded = self.discriminator(encoded_images.detach())
            d_loss_on_encoded = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            d_loss_on_encoded.backward()
            self.optimizer_discrim.step()

            # --------------Train the generator (encoder-decoder) ---------------------
            self.optimizer_enc_dec.zero_grad()
            # target label for encoded images should be 'cover', because we want to fool the discriminator
            d_on_encoded_for_enc = self.discriminator(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)

            if self.vgg_loss is None:
                g_loss_enc = self.mse_loss(encoded_images, images)
            else:
                vgg_on_cov = self.vgg_loss(images)
                vgg_on_enc = self.vgg_loss(encoded_images)
                g_loss_enc = self.mse_loss(vgg_on_cov, vgg_on_enc)

            g_loss_dec = self.mse_loss(decoded_messages, messages)
            g_loss = self.config.adversarial_loss * g_loss_adv + self.config.encoder_loss * g_loss_enc \
                     + self.config.decoder_loss * g_loss_dec

            g_loss.backward()
            self.optimizer_enc_dec.step()

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_avg_err = np.sum(
            np.abs(decoded_rounded - messages.detach().cpu().numpy())) / (
                batch_size * messages.shape[1])

        losses = {
            'loss           ': g_loss.item(),
            'encoder_mse    ': g_loss_enc.item(),
            'dec_mse        ': g_loss_dec.item(),
            'bitwise-error  ': bitwise_avg_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encoded.item()
        }
        return losses, (encoded_images, noised_images, decoded_messages)

    def validate_on_batch(self, batch: list):
        """
        Runs validation on a single batch of data consisting of images and messages
        :param batch: batch of validation data, in form [images, messages]
        :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch
        """
        # if TensorboardX logging is enabled, save some of the tensors.
        if self.tb_logger is not None:
            encoder_final = self.encoder_decoder.encoder._modules[
                'final_layer']
            self.tb_logger.add_tensor('weights/encoder_out',
                                      encoder_final.weight)
            decoder_final = self.encoder_decoder.decoder._modules['linear']
            self.tb_logger.add_tensor('weights/decoder_out',
                                      decoder_final.weight)
            discrim_final = self.discriminator._modules['linear']
            self.tb_logger.add_tensor('weights/discrim_out',
                                      discrim_final.weight)

        images, messages = batch

        batch_size = images.shape[0]

        self.encoder_decoder.eval()
        self.discriminator.eval()
        with torch.no_grad():
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              dtype=torch.float,
                                              device=self.device)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encoded_label,
                                                dtype=torch.float,
                                                device=self.device)
            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                dtype=torch.float,
                                                device=self.device)

            d_on_cover = self.discriminator(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)

            encoded_images, noised_images, decoded_messages = self.encoder_decoder(
                images, messages)

            d_on_encoded = self.discriminator(encoded_images)
            d_loss_on_encoded = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            d_on_encoded_for_enc = self.discriminator(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)

            if self.vgg_loss is None:
                g_loss_enc = self.mse_loss(encoded_images, images)
            else:
                vgg_on_cov = self.vgg_loss(images)
                vgg_on_enc = self.vgg_loss(encoded_images)
                g_loss_enc = self.mse_loss(vgg_on_cov, vgg_on_enc)

            g_loss_dec = self.mse_loss(decoded_messages, messages)
            g_loss = self.config.adversarial_loss * g_loss_adv + self.config.encoder_loss * g_loss_enc \
                     + self.config.decoder_loss * g_loss_dec

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_avg_err = np.sum(
            np.abs(decoded_rounded - messages.detach().cpu().numpy())) / (
                batch_size * messages.shape[1])

        losses = {
            'loss           ': g_loss.item(),
            'encoder_mse    ': g_loss_enc.item(),
            'dec_mse        ': g_loss_dec.item(),
            'bitwise-error  ': bitwise_avg_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encoded.item()
        }
        return losses, (encoded_images, noised_images, decoded_messages)

    def to_string(self):
        return '{}\n{}'.format(str(self.encoder_decoder),
                               str(self.discriminator))
Example #2
class HiDDen(object):
    def __init__(self, config: HiDDenConfiguration, device: torch.device):
        self.enc_dec = EncoderDecoder(config).to(device)
        self.discr = Discriminator(config).to(device)
        self.opt_enc_dec = torch.optim.Adam(self.enc_dec.parameters())
        self.opt_discr = torch.optim.Adam(self.discr.parameters())

        self.config = config
        self.device = device
        self.bce_with_logits_loss = nn.BCEWithLogitsLoss().to(device)
        self.mse_loss = nn.MSELoss().to(device)

        self.cover_label = 1
        self.encod_label = 0

    def train_on_batch(self, batch: list):
        '''
        Trains the network on a single batch consisting of images and messages
        '''
        images, messages = batch
        batch_size = images.shape[0]
        self.enc_dec.train()
        self.discr.train()

        with torch.enable_grad():
            # ---------- Train the discriminator----------
            self.opt_discr.zero_grad()

            # train on cover
            # pin float dtype: BCEWithLogitsLoss expects float targets
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              dtype=torch.float,
                                              device=self.device)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encod_label,
                                                dtype=torch.float,
                                                device=self.device)
            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                dtype=torch.float,
                                                device=self.device)

            d_on_cover = self.discr(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)
            d_loss_on_cover.backward()

            # train on fake
            encoded_images, decoded_messages = self.enc_dec(images, messages)
            d_on_encoded = self.discr(encoded_images.detach())
            d_loss_on_encod = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)
            d_loss_on_encod.backward()
            self.opt_discr.step()

            # ---------- Train the generator ----------
            self.opt_enc_dec.zero_grad()

            d_on_encoded_for_enc = self.discr(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)
            g_loss_enc = self.mse_loss(encoded_images, images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)

            g_loss = self.config.adversarial_loss * g_loss_adv \
                    + self.config.encoder_loss * g_loss_enc \
                    + self.config.decoder_loss * g_loss_dec
            g_loss.backward()
            self.opt_enc_dec.step()

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_err = np.sum(np.abs(decoded_rounded - messages.detach().cpu().numpy())) \
                      / (batch_size * messages.shape[1])

        losses = {
            'loss': g_loss.item(),
            'encoder_mse': g_loss_enc.item(),
            'decoder_mse': g_loss_dec.item(),
            'bitwise-error': bitwise_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encod.item()
        }

        return losses, (encoded_images, decoded_messages)

    def validate_on_batch(self, batch: list):
        '''Runs validation on a single batch consisting of [images, messages]'''
        images, messages = batch
        batch_size = images.shape[0]

        self.enc_dec.eval()
        self.discr.eval()

        with torch.no_grad():
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              dtype=torch.float,
                                              device=self.device)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encod_label,
                                                dtype=torch.float,
                                                device=self.device)
            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                dtype=torch.float,
                                                device=self.device)

            d_on_cover = self.discr(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)

            encoded_images, decoded_messages = self.enc_dec(images, messages)
            d_on_encoded = self.discr(encoded_images)
            d_loss_on_encod = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            d_on_encoded_for_enc = self.discr(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)
            g_loss_enc = self.mse_loss(encoded_images, images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)

            g_loss = self.config.adversarial_loss * g_loss_adv \
                    + self.config.encoder_loss * g_loss_enc \
                    + self.config.decoder_loss * g_loss_dec

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_err = np.sum(np.abs(decoded_rounded - messages.detach().cpu().numpy()))\
                     / (batch_size * messages.shape[1])

        losses = {
            'loss': g_loss.item(),
            'encoder_mse': g_loss_enc.item(),
            'decoder_mse': g_loss_dec.item(),
            'bitwise-error': bitwise_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encod.item()
        }

        return losses, (encoded_images, decoded_messages)

    def to_string(self):
        return f'{str(self.enc_dec)}\n{str(self.discr)}'
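Both classes report the same bitwise-error metric: decoder outputs are rounded to {0, 1} and compared bit-by-bit against the ground-truth message. A standalone sketch of that computation with made-up numbers:

import numpy as np

messages = np.array([[1, 0, 1, 1],
                     [0, 0, 1, 0]], dtype=np.float32)         # ground-truth bits
decoded = np.array([[0.9, 0.2, 0.7, 0.4],
                    [0.1, 0.6, 0.8, 0.3]], dtype=np.float32)  # raw decoder outputs

decoded_rounded = decoded.round().clip(0, 1)
batch_size, msg_len = messages.shape
bitwise_avg_err = np.sum(np.abs(decoded_rounded - messages)) / (batch_size * msg_len)
print(bitwise_avg_err)  # 0.25 -- two of the eight bits are wrong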
Example #3
def main():
    ## load std models
    # policy_log_std = torch.load('./model_pkl/policy_net_action_std_model_1.pkl')
    # transition_log_std = torch.load('./model_pkl/transition_net_state_std_model_1.pkl')

    # load expert data
    print(args.data_set_path)
    dataset = ExpertDataSet(args.data_set_path)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=args.expert_batch_size,
                                  shuffle=True,
                                  num_workers=0)
    # define actor/critic/discriminator net and optimizer
    policy = Policy(onehot_action_sections,
                    onehot_state_sections,
                    state_0=dataset.state)
    value = Value()
    discriminator = Discriminator()
    optimizer_policy = torch.optim.Adam(policy.parameters(), lr=args.policy_lr)
    optimizer_value = torch.optim.Adam(value.parameters(), lr=args.value_lr)
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=args.discrim_lr)
    discriminator_criterion = nn.BCELoss()
    if write_scalar:
        writer = SummaryWriter(log_dir='runs/' + model_name)

    # load net  models
    if load_model:
        discriminator.load_state_dict(
            torch.load('./model_pkl/Discriminator_model_' + model_name +
                       '.pkl'))
        policy.transition_net.load_state_dict(
            torch.load('./model_pkl/Transition_model_' + model_name + '.pkl'))
        policy.policy_net.load_state_dict(
            torch.load('./model_pkl/Policy_model_' + model_name + '.pkl'))
        value.load_state_dict(
            torch.load('./model_pkl/Value_model_' + model_name + '.pkl'))

        policy.policy_net_action_std = torch.load(
            './model_pkl/Policy_net_action_std_model_' + model_name + '.pkl')
        policy.transition_net_state_std = torch.load(
            './model_pkl/Transition_net_state_std_model_' + model_name +
            '.pkl')
    print('#############  start training  ##############')

    # update discriminator
    num = 0
    for ep in tqdm(range(args.training_epochs)):
        # collect data from environment for ppo update
        policy.train()
        value.train()
        discriminator.train()
        start_time = time.time()
        memory, n_trajs = policy.collect_samples(
            batch_size=args.sample_batch_size)
        # print('sample_data_time:{}'.format(time.time()-start_time))
        batch = memory.sample()
        onehot_state = torch.cat(batch.onehot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        multihot_state = torch.cat(batch.multihot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        continuous_state = torch.cat(batch.continuous_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()

        onehot_action = torch.cat(batch.onehot_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        multihot_action = torch.cat(batch.multihot_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        continuous_action = torch.cat(batch.continuous_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        next_onehot_state = torch.cat(batch.next_onehot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        next_multihot_state = torch.cat(batch.next_multihot_state,
                                        dim=1).reshape(
                                            n_trajs * args.sample_traj_length,
                                            -1).detach()
        next_continuous_state = torch.cat(
            batch.next_continuous_state,
            dim=1).reshape(n_trajs * args.sample_traj_length, -1).detach()

        old_log_prob = torch.cat(batch.old_log_prob, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        mask = torch.cat(batch.mask,
                         dim=1).reshape(n_trajs * args.sample_traj_length,
                                        -1).detach()
        gen_state = torch.cat((onehot_state, multihot_state, continuous_state),
                              dim=-1)
        gen_action = torch.cat(
            (onehot_action, multihot_action, continuous_action), dim=-1)
        # update the discriminator every epoch; `ep % 1 == 0` is always true and
        # serves as a placeholder for a configurable update interval
        if ep % 1 == 0:
            # if (d_slow_flag and ep % 50 == 0) or (not d_slow_flag and ep % 1 == 0):
            d_loss = torch.empty(0, device=device)
            p_loss = torch.empty(0, device=device)
            v_loss = torch.empty(0, device=device)
            gen_r = torch.empty(0, device=device)
            expert_r = torch.empty(0, device=device)
            for expert_state_batch, expert_action_batch in data_loader:
                noise1 = torch.normal(0,
                                      args.noise_std,
                                      size=gen_state.shape,
                                      device=device)
                noise2 = torch.normal(0,
                                      args.noise_std,
                                      size=gen_action.shape,
                                      device=device)
                noise3 = torch.normal(0,
                                      args.noise_std,
                                      size=expert_state_batch.shape,
                                      device=device)
                noise4 = torch.normal(0,
                                      args.noise_std,
                                      size=expert_action_batch.shape,
                                      device=device)
                gen_r = discriminator(gen_state + noise1, gen_action + noise2)
                expert_r = discriminator(
                    expert_state_batch.to(device) + noise3,
                    expert_action_batch.to(device) + noise4)

                # gen_r = discriminator(gen_state, gen_action)
                # expert_r = discriminator(expert_state_batch.to(device), expert_action_batch.to(device))
                optimizer_discriminator.zero_grad()
                d_loss = discriminator_criterion(gen_r, torch.zeros(gen_r.shape, device=device)) + \
                            discriminator_criterion(expert_r,torch.ones(expert_r.shape, device=device))
                variance = 0.5 * torch.var(gen_r.to(device)) + 0.5 * torch.var(
                    expert_r.to(device))
                total_d_loss = d_loss - 10 * variance
                d_loss.backward()
                # total_d_loss.backward()
                optimizer_discriminator.step()
            if write_scalar:
                writer.add_scalar('d_loss', d_loss, ep)
                writer.add_scalar('total_d_loss', total_d_loss, ep)
                writer.add_scalar('variance', 10 * variance, ep)
        if ep % 1 == 0:  # same always-true placeholder interval as above
            # update PPO
            noise1 = torch.normal(0,
                                  args.noise_std,
                                  size=gen_state.shape,
                                  device=device)
            noise2 = torch.normal(0,
                                  args.noise_std,
                                  size=gen_action.shape,
                                  device=device)
            gen_r = discriminator(gen_state + noise1, gen_action + noise2)
            #if gen_r.mean().item() < 0.1:
            #    d_stop = True
            #if d_stop and gen_r.mean()
            optimize_iter_num = int(
                math.ceil(onehot_state.shape[0] / args.ppo_mini_batch_size))
            # gen_r = -(1 - gen_r + 1e-10).log()
            for ppo_ep in range(args.ppo_optim_epoch):
                for i in range(optimize_iter_num):
                    num += 1
                    index = slice(
                        i * args.ppo_mini_batch_size,
                        min((i + 1) * args.ppo_mini_batch_size,
                            onehot_state.shape[0]))
                    onehot_state_batch = onehot_state[index]
                    multihot_state_batch = multihot_state[index]
                    continuous_state_batch = continuous_state[index]
                    onehot_action_batch = onehot_action[index]
                    multihot_action_batch = multihot_action[index]
                    continuous_action_batch = continuous_action[index]
                    old_log_prob_batch = old_log_prob[index]
                    mask_batch = mask[index]
                    next_onehot_state_batch = next_onehot_state[index]
                    next_multihot_state_batch = next_multihot_state[index]
                    next_continuous_state_batch = next_continuous_state[index]
                    gen_r_batch = gen_r[index]
                    v_loss, p_loss = ppo_step(
                        policy, value, optimizer_policy, optimizer_value,
                        onehot_state_batch, multihot_state_batch,
                        continuous_state_batch, onehot_action_batch,
                        multihot_action_batch, continuous_action_batch,
                        next_onehot_state_batch, next_multihot_state_batch,
                        next_continuous_state_batch, gen_r_batch,
                        old_log_prob_batch, mask_batch, args.ppo_clip_epsilon)
                    if write_scalar:
                        writer.add_scalar('p_loss', p_loss, ep)
                        writer.add_scalar('v_loss', v_loss, ep)
        policy.eval()
        value.eval()
        discriminator.eval()
        noise1 = torch.normal(0,
                              args.noise_std,
                              size=gen_state.shape,
                              device=device)
        noise2 = torch.normal(0,
                              args.noise_std,
                              size=gen_action.shape,
                              device=device)
        gen_r = discriminator(gen_state + noise1, gen_action + noise2)
        # expert_state_batch, noise3 and noise4 are the leftovers from the last
        # iteration of the discriminator update loop above
        expert_r = discriminator(
            expert_state_batch.to(device) + noise3,
            expert_action_batch.to(device) + noise4)
        gen_r_noise = gen_r.mean().item()
        expert_r_noise = expert_r.mean().item()
        gen_r = discriminator(gen_state, gen_action)
        expert_r = discriminator(expert_state_batch.to(device),
                                 expert_action_batch.to(device))
        if write_scalar:
            writer.add_scalar('gen_r', gen_r.mean(), ep)
            writer.add_scalar('expert_r', expert_r.mean(), ep)
            writer.add_scalar('gen_r_noise', gen_r_noise, ep)
            writer.add_scalar('expert_r_noise', expert_r_noise, ep)
        print('#' * 5 + 'training episode:{}'.format(ep) + '#' * 5)
        print('gen_r_noise', gen_r_noise)
        print('expert_r_noise', expert_r_noise)
        print('gen_r:', gen_r.mean().item())
        print('expert_r:', expert_r.mean().item())
        print('d_loss', d_loss.item())
        # save models
        if model_name is not None:
            torch.save(
                discriminator.state_dict(),
                './model_pkl/Discriminator_model_' + model_name + '.pkl')
            torch.save(policy.transition_net.state_dict(),
                       './model_pkl/Transition_model_' + model_name + '.pkl')
            torch.save(policy.policy_net.state_dict(),
                       './model_pkl/Policy_model_' + model_name + '.pkl')
            torch.save(
                policy.policy_net_action_std,
                './model_pkl/Policy_net_action_std_model_' + model_name +
                '.pkl')
            torch.save(
                policy.transition_net_state_std,
                './model_pkl/Transition_net_state_std_model_' + model_name +
                '.pkl')
            torch.save(value.state_dict(),
                       './model_pkl/Value_model_' + model_name + '.pkl')
        memory.clear_memory()
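The training loop above adds fresh Gaussian noise (noise1..noise4) to every state/action pair before it reaches the discriminator, a common stabilization trick sometimes called instance noise. A minimal runnable sketch of the pattern; the network, shapes, and noise_std value are illustrative stand-ins:

import torch
import torch.nn as nn

torch.manual_seed(0)
noise_std = 0.1                                  # stands in for args.noise_std
discriminator = nn.Sequential(nn.Linear(8, 16), nn.ReLU(),
                              nn.Linear(16, 1), nn.Sigmoid())
criterion = nn.BCELoss()

gen_sa = torch.randn(32, 8)                      # stand-in generated (state, action)
expert_sa = torch.randn(32, 8)                   # stand-in expert (state, action)

# perturb both inputs with fresh noise before scoring them
gen_r = discriminator(gen_sa + torch.normal(0, noise_std, size=gen_sa.shape))
expert_r = discriminator(expert_sa + torch.normal(0, noise_std, size=expert_sa.shape))

# generated samples are pushed toward 0, expert samples toward 1
d_loss = criterion(gen_r, torch.zeros_like(gen_r)) + \
         criterion(expert_r, torch.ones_like(expert_r))
d_loss.backward()
print(d_loss.item())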
Example #4
class Trainer(nn.Module):
    def __init__(self, model_dir, g_optimizer, d_optimizer, lr, warmup,
                 max_iters):
        super().__init__()
        self.model_dir = model_dir
        if not os.path.exists(f'checkpoints/{model_dir}'):
            os.makedirs(f'checkpoints/{model_dir}')
        self.logs_dir = f'checkpoints/{model_dir}/logs'
        if not os.path.exists(self.logs_dir):
            os.makedirs(self.logs_dir)
        self.writer = SummaryWriter(self.logs_dir)

        self.arcface = ArcFaceNet(50, 0.6, 'ir_se').cuda()
        self.arcface.eval()
        self.arcface.load_state_dict(torch.load(
            'checkpoints/model_ir_se50.pth', map_location='cuda'),
                                     strict=False)

        self.mobiface = MobileFaceNet(512).cuda()
        self.mobiface.eval()
        self.mobiface.load_state_dict(torch.load(
            'checkpoints/mobilefacenet.pth', map_location='cuda'),
                                      strict=False)

        self.generator = Generator().cuda()
        self.discriminator = Discriminator().cuda()

        self.adversarial_weight = 1
        self.src_id_weight = 5
        self.tgt_id_weight = 1
        self.attributes_weight = 10
        self.reconstruction_weight = 10

        self.lr = lr
        self.warmup = warmup
        self.g_optimizer = g_optimizer(self.generator.parameters(),
                                       lr=lr,
                                       betas=(0, 0.999))
        self.d_optimizer = d_optimizer(self.discriminator.parameters(),
                                       lr=lr,
                                       betas=(0, 0.999))

        self.generator, self.g_optimizer = amp.initialize(self.generator,
                                                          self.g_optimizer,
                                                          opt_level="O1")
        self.discriminator, self.d_optimizer = amp.initialize(
            self.discriminator, self.d_optimizer, opt_level="O1")

        self._iter = nn.Parameter(torch.tensor(1), requires_grad=False)
        self.max_iters = max_iters

        if torch.cuda.is_available():
            self.cuda()

    @property
    def iter(self):
        return self._iter.item()

    @property
    def device(self):
        return next(self.parameters()).device

    def adapt(self, args):
        device = self.device
        return [arg.to(device) for arg in args]

    def train_loop(self, dataloaders, eval_every, generate_every, save_every):
        for batch in tqdm(dataloaders['train']):
            self._iter.add_(1)
            # generator step
            # if self.iter % 2 == 0:
            # self.adjust_lr(self.g_optimizer)
            g_losses = self.g_step(self.adapt(batch))
            g_stats = self.get_opt_stats(self.g_optimizer, type='generator')
            self.write_logs(losses=g_losses, stats=g_stats, type='generator')

            # #discriminator step
            # if self.iter % 2 == 1:
            # self.adjust_lr(self.d_optimizer)
            d_losses = self.d_step(self.adapt(batch))
            d_stats = self.get_opt_stats(self.d_optimizer,
                                         type='discriminator')
            self.write_logs(losses=d_losses,
                            stats=d_stats,
                            type='discriminator')

            if self.iter % eval_every == 0:
                discriminator_acc = self.evaluate_discriminator_accuracy(
                    dataloaders['val'])
                identification_acc = self.evaluate_identification_similarity(
                    dataloaders['val'])
                metrics = {**discriminator_acc, **identification_acc}
                self.write_logs(metrics=metrics)

            if self.iter % generate_every == 0:
                self.generate(*self.adapt(batch))

            if self.iter % save_every == 0:
                self.save_discriminator()
                self.save_generator()

    def g_step(self, batch):
        self.generator.train()
        self.g_optimizer.zero_grad()
        L_adv, L_src_id, L_tgt_id, L_attr, L_rec, L_generator = self.g_loss(
            *batch)
        with amp.scale_loss(L_generator, self.g_optimizer) as scaled_loss:
            scaled_loss.backward()
        self.g_optimizer.step()

        losses = {
            'adv': L_adv.item(),
            'src_id': L_src_id.item(),
            'tgt_id': L_tgt_id.item(),
            'attributes': L_attr.item(),
            'reconstruction': L_rec.item(),
            'total_loss': L_generator.item()
        }
        return losses

    def d_step(self, batch):
        self.discriminator.train()
        self.d_optimizer.zero_grad()
        L_fake, L_real, L_discriminator = self.d_loss(*batch)
        with amp.scale_loss(L_discriminator, self.d_optimizer) as scaled_loss:
            scaled_loss.backward()
        self.d_optimizer.step()

        losses = {
            'hinge_fake': L_fake.item(),
            'hinge_real': L_real.item(),
            'total_loss': L_discriminator.item()
        }
        return losses

    def g_loss(self, Xs, Xt, same_person):
        with torch.no_grad():
            src_embed = self.arcface(
                F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))
            tgt_embed = self.arcface(
                F.interpolate(Xt[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))

        Y_hat, Xt_attr = self.generator(Xt, src_embed, return_attributes=True)

        Di = self.discriminator(Y_hat)

        L_adv = 0
        for di in Di:
            L_adv += hinge_loss(di[0], True)

        fake_embed = self.arcface(
            F.interpolate(Y_hat[:, :, 19:237, 19:237], [112, 112],
                          mode='bilinear',
                          align_corners=True))
        L_src_id = (
            1 - torch.cosine_similarity(src_embed, fake_embed, dim=1)).mean()
        L_tgt_id = (
            1 - torch.cosine_similarity(tgt_embed, fake_embed, dim=1)).mean()

        batch_size = Xs.shape[0]
        Y_hat_attr = self.generator.get_attr(Y_hat)
        L_attr = 0
        for i in range(len(Xt_attr)):
            L_attr += torch.mean(torch.pow(Xt_attr[i] - Y_hat_attr[i],
                                           2).reshape(batch_size, -1),
                                 dim=1).mean()
        L_attr /= 2.0

        L_rec = torch.sum(
            0.5 * torch.mean(torch.pow(Y_hat - Xt, 2).reshape(batch_size, -1),
                             dim=1) * same_person) / (same_person.sum() + 1e-6)
        L_generator = (self.adversarial_weight *
                       L_adv) + (self.src_id_weight * L_src_id) + (
                           self.tgt_id_weight *
                           L_tgt_id) + (self.attributes_weight * L_attr) + (
                               self.reconstruction_weight * L_rec)
        return L_adv, L_src_id, L_tgt_id, L_attr, L_rec, L_generator

    def d_loss(self, Xs, Xt, same_person):
        with torch.no_grad():
            src_embed = self.arcface(
                F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))
        Y_hat = self.generator(Xt, src_embed, return_attributes=False)

        fake_D = self.discriminator(Y_hat.detach())
        L_fake = 0
        for di in fake_D:
            L_fake += hinge_loss(di[0], False)
        real_D = self.discriminator(Xs)
        L_real = 0
        for di in real_D:
            L_real += hinge_loss(di[0], True)

        L_discriminator = 0.5 * (L_real + L_fake)
        return L_fake, L_real, L_discriminator

    def evaluate_discriminator_accuracy(self, val_dataloader):
        real_acc = 0
        fake_acc = 0
        self.generator.eval()
        self.discriminator.eval()
        for batch in tqdm(val_dataloader):
            Xs, Xt, _ = self.adapt(batch)

            with torch.no_grad():
                embed = self.arcface(
                    F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                Y_hat = self.generator(Xt, embed, return_attributes=False)
                fake_D = self.discriminator(Y_hat)
                real_D = self.discriminator(Xs)

            fake_multiscale_acc = 0
            for di in fake_D:
                fake_multiscale_acc += torch.mean((di[0] < 0).float())
            fake_acc += fake_multiscale_acc / len(fake_D)

            real_multiscale_acc = 0
            for di in real_D:
                real_multiscale_acc += torch.mean((di[0] > 0).float())
            real_acc += real_multiscale_acc / len(real_D)

        self.generator.train()
        self.discriminator.train()

        metrics = {
            'fake_acc': 100 * (fake_acc / len(val_dataloader)).item(),
            'real_acc': 100 * (real_acc / len(val_dataloader)).item()
        }
        return metrics

    def evaluate_identification_similarity(self, val_dataloader):
        src_id_sim = 0
        tgt_id_sim = 0
        self.generator.eval()
        for batch in tqdm(val_dataloader):
            Xs, Xt, _ = self.adapt(batch)
            with torch.no_grad():
                src_embed = self.arcface(
                    F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                Y_hat = self.generator(Xt, src_embed, return_attributes=False)

                src_embed = self.mobiface(
                    F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                tgt_embed = self.mobiface(
                    F.interpolate(Xt[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                fake_embed = self.mobiface(
                    F.interpolate(Y_hat[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))

            src_id_sim += (torch.cosine_similarity(src_embed,
                                                   fake_embed,
                                                   dim=1)).float().mean()
            tgt_id_sim += (torch.cosine_similarity(tgt_embed,
                                                   fake_embed,
                                                   dim=1)).float().mean()

        self.generator.train()

        metrics = {
            'src_similarity': 100 * (src_id_sim / len(val_dataloader)).item(),
            'tgt_similarity': 100 * (tgt_id_sim / len(val_dataloader)).item()
        }
        return metrics

    def generate(self, Xs, Xt, same_person):
        def get_grid_image(X):
            X = X[:8]
            X = torchvision.utils.make_grid(X.detach().cpu(), nrow=X.shape[0])
            X = (X * 0.5 + 0.5) * 255
            return X

        def make_image(Xs, Xt, Y_hat):
            Xs = get_grid_image(Xs)
            Xt = get_grid_image(Xt)
            Y_hat = get_grid_image(Y_hat)
            return torch.cat((Xs, Xt, Y_hat), dim=1).numpy()

        with torch.no_grad():
            embed = self.arcface(
                F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))
            self.generator.eval()
            Y_hat = self.generator(Xt, embed, return_attributes=False)
            self.generator.train()

        image = make_image(Xs, Xt, Y_hat)
        if not os.path.exists(f'results/{self.model_dir}'):
            os.makedirs(f'results/{self.model_dir}')
        cv2.imwrite(f'results/{self.model_dir}/{self.iter}.jpg',
                    image.transpose([1, 2, 0]))

    def get_opt_stats(self, optimizer, type=''):
        stats = {f'{type}_lr': optimizer.param_groups[0]['lr']}
        return stats

    def adjust_lr(self, optimizer):
        if self.iter <= self.warmup:
            lr = self.lr * self.iter / self.warmup
        else:
            lr = self.lr * (1 + cos(pi * (self.iter - self.warmup) /
                                    (self.max_iters - self.warmup))) / 2

        for group in optimizer.param_groups:
            group['lr'] = lr
        return lr

    def write_logs(self, losses=None, metrics=None, stats=None, type='loss'):
        if losses:
            for name, value in losses.items():
                self.writer.add_scalar(f'{type}/{name}', value, self.iter)
        if metrics:
            for name, value in metrics.items():
                self.writer.add_scalar(f'metric/{name}', value, self.iter)
        if stats:
            for name, value in stats.items():
                self.writer.add_scalar(f'stats/{name}', value, self.iter)

    def save_generator(self, max_checkpoints=100):
        # prune the actual save directory; the oldest checkpoint goes first
        checkpoints = glob.glob(f'checkpoints/{self.model_dir}/generator_*.pt')
        if len(checkpoints) > max_checkpoints:
            os.remove(min(checkpoints, key=os.path.getctime))
        with open(f'checkpoints/{self.model_dir}/generator_{self.iter}.pt',
                  'wb') as f:
            torch.save(self.generator.state_dict(), f)

    def save_discriminator(self, max_checkpoints=100):
        checkpoints = glob.glob(f'checkpoints/{self.model_dir}/discriminator_*.pt')
        if len(checkpoints) > max_checkpoints:
            os.remove(min(checkpoints, key=os.path.getctime))
        with open(f'checkpoints/{self.model_dir}/discriminator_{self.iter}.pt',
                  'wb') as f:
            torch.save(self.discriminator.state_dict(), f)

    def load_discriminator(self, path, load_last=True):
        if load_last:
            try:
                checkpoints = glob.glob(f'{path}/discriminator*.pt')
                path = max(checkpoints, key=os.path.getctime)
            except ValueError:
                print(f'Directory is empty: {path}')

        try:
            self.discriminator.load_state_dict(torch.load(path))
            self.cuda()
        except FileNotFoundError:
            print(f'No such file: {path}')

    def load_generator(self, path, load_last=True):
        if load_last:
            try:
                checkpoints = glob.glob(f'{path}/generator*.pt')
                path = max(checkpoints, key=os.path.getctime)
            except ValueError:
                print(f'Directory is empty: {path}')

        try:
            self.generator.load_state_dict(torch.load(path))
            # recover the iteration count from the checkpoint filename only
            iter_str = ''.join(filter(str.isdigit, os.path.basename(path)))
            self._iter = nn.Parameter(torch.tensor(int(iter_str)),
                                      requires_grad=False)
            self.cuda()
        except FileNotFoundError:
            print(f'No such file: {path}')
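adjust_lr above implements linear warmup for the first warmup iterations followed by a cosine decay to zero at max_iters. The schedule is easy to check in isolation; a standalone sketch with illustrative hyperparameters, not the Trainer's real settings:

from math import cos, pi

def cosine_lr(iteration, base_lr, warmup, max_iters):
    """Linear warmup to base_lr, then cosine decay to 0 at max_iters."""
    if iteration <= warmup:
        return base_lr * iteration / warmup
    return base_lr * (1 + cos(pi * (iteration - warmup) /
                              (max_iters - warmup))) / 2

for it in (1, 500, 1000, 5500, 10000):
    print(it, f'{cosine_lr(it, base_lr=4e-4, warmup=1000, max_iters=10000):.6f}')

At iteration 5500 the decay is exactly halfway through, so the printed value is base_lr / 2.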
Example #5
    gv2_model_path = "/home/nhli/PycharmProj/ReIDGAN_/params/record-step-12685-model.pkl"
    # trp-fc-new
    dis_model_path = "/home/nhli/PycharmProj/ReIDGAN_/workdir/trp-fc-new/save-dis-4599"
    # adv-0
    # dis_model_path = "/home/nhli/PycharmProj/ReIDGAN_/workdir/adv-0/save-sel_4-49"

    # build graph
    feature_extractor = Inceptionv2()
    fea_ext_dict = torch.load(gv2_model_path)
    fea_ext_dict.pop('classifier.weight')
    fea_ext_dict.pop('classifier.bias')
    fea_ext_dict.pop('criterion2.center_feature')
    fea_ext_dict.pop('criterion2.all_labels')
    feature_extractor.load_state_dict(fea_ext_dict)

    dis = Discriminator()
    # dis = Selector()
    dis.load_state_dict(torch.load(dis_model_path))

    # eval
    dis.eval()
    feature_extractor.eval()

    if is_cuda:
        dis.cuda()
        feature_extractor.cuda()

    dataset_feature = extract_all_feature(batch_size, feature_extractor)

    # bind the result to a new name instead of shadowing the mAP function
    mAP_score = mAP(dis, dataset_feature, batch_size)
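The snippet pops the classifier head and the criterion2 buffers out of the checkpoint before load_state_dict, because the inference-time Inceptionv2 no longer has those training-only parameters. A self-contained sketch of the same pattern on a toy module (the class and key names are illustrative):

import torch
import torch.nn as nn

class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.feat = nn.Linear(4, 8)

class TrainModel(Backbone):
    def __init__(self):
        super().__init__()
        self.classifier = nn.Linear(8, 10)   # training-only head

state = TrainModel().state_dict()
# drop the keys the inference model does not have, then load strictly
for key in ('classifier.weight', 'classifier.bias'):
    state.pop(key)

backbone = Backbone()
backbone.load_state_dict(state)              # succeeds: remaining keys match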