Example #1
def main():

    # set torch and numpy seed for reproducibility
    torch.manual_seed(27)
    np.random.seed(27)

    # tensorboard writer
    writer = SummaryWriter(settings.TENSORBOARD_DIR)
    # makedir snapshot
    makedir(settings.CHECKPOINT_DIR)
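    # (`makedir` is assumed to be a small helper along the lines of
    # os.makedirs(path, exist_ok=True); it is not defined in this snippet)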

    # enable cudnn
    torch.backends.cudnn.enabled = True

    # create segmentor network
    model_G = Segmentor(pretrained=settings.PRETRAINED,
                        num_classes=settings.NUM_CLASSES,
                        modality=settings.MODALITY)

    model_G.train()
    model_G.cuda()

    torch.backends.cudnn.benchmark = True

    # create discriminator network
    model_D = Discriminator(settings.NUM_CLASSES)
    model_D.train()
    model_D.cuda()

    # dataset and dataloader
    dataset = TrainDataset()
    dataloader = data.DataLoader(dataset,
                                 batch_size=settings.BATCH_SIZE,
                                 shuffle=True,
                                 num_workers=settings.NUM_WORKERS,
                                 pin_memory=True,
                                 drop_last=True)

    test_dataset = TestDataset(data_root=settings.DATA_ROOT_VAL,
                               data_list=settings.DATA_LIST_VAL)
    test_dataloader = data.DataLoader(test_dataset,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=settings.NUM_WORKERS,
                                      pin_memory=True)

    # optimizer for generator network (segmentor)
    optim_G = optim.SGD(model_G.optim_parameters(settings.LR),
                        lr=settings.LR,
                        momentum=settings.LR_MOMENTUM,
                        weight_decay=settings.WEIGHT_DECAY)

    # lr scheduler for optim_G
    lr_lambda_G = lambda epoch: (1 - epoch / settings.EPOCHS
                                 )**settings.LR_POLY_POWER
    lr_scheduler_G = optim.lr_scheduler.LambdaLR(optim_G,
                                                 lr_lambda=lr_lambda_G)
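    # e.g. with EPOCHS = 100 and LR_POLY_POWER = 0.9 (hypothetical values),
    # the decay factor at epoch 50 is (1 - 50/100)**0.9 ≈ 0.536, falling
    # smoothly to 0 at the final epoch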

    # optimizer for discriminator network
    optim_D = optim.Adam(model_D.parameters(), settings.LR_D)

    # lr scheduler for optim_D
    lr_lambda_D = lambda epoch: (1 - epoch / settings.EPOCHS
                                 )**settings.LR_POLY_POWER
    lr_scheduler_D = optim.lr_scheduler.LambdaLR(optim_D,
                                                 lr_lambda=lr_lambda_D)

    # losses
    ce_loss = CrossEntropyLoss2d(
        ignore_index=settings.IGNORE_LABEL)  # to use for segmentor
    bce_loss = BCEWithLogitsLoss2d()  # to use for discriminator

    # upsampling for the network output
    upsample = nn.Upsample(size=(settings.CROP_SIZE, settings.CROP_SIZE),
                           mode='bilinear',
                           align_corners=True)

    # # labels for adversarial training
    # pred_label = 0
    # gt_label = 1

    # load the model to resume training
    last_epoch = -1
    if settings.RESUME_TRAIN:
        checkpoint = torch.load(settings.LAST_CHECKPOINT)

        model_G.load_state_dict(checkpoint['model_G_state_dict'])
        model_G.train()
        model_G.cuda()

        model_D.load_state_dict(checkpoint['model_D_state_dict'])
        model_D.train()
        model_D.cuda()

        optim_G.load_state_dict(checkpoint['optim_G_state_dict'])
        optim_D.load_state_dict(checkpoint['optim_D_state_dict'])

        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G_state_dict'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D_state_dict'])

        last_epoch = checkpoint['epoch']

        # purge the logs after the last_epoch
        writer = SummaryWriter(settings.TENSORBOARD_DIR,
                               purge_step=(last_epoch + 1) * len(dataloader))

    for epoch in range(last_epoch + 1, settings.EPOCHS + 1):

        train_one_epoch(model_G,
                        model_D,
                        optim_G,
                        optim_D,
                        dataloader,
                        test_dataloader,
                        epoch,
                        upsample,
                        ce_loss,
                        bce_loss,
                        writer,
                        print_freq=5,
                        eval_freq=settings.EVAL_FREQ)

        if epoch % settings.CHECKPOINT_FREQ == 0 and epoch != 0:
            save_checkpoint(epoch, model_G, model_D, optim_G, optim_D,
                            lr_scheduler_G, lr_scheduler_D)

        # save the final model
        if epoch >= settings.EPOCHS:
            print('saving the final model')
            save_checkpoint(epoch, model_G, model_D, optim_G, optim_D,
                            lr_scheduler_G, lr_scheduler_D)
            writer.close()

        lr_scheduler_G.step()
        lr_scheduler_D.step()
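
# `save_checkpoint` is not defined in this snippet. A minimal sketch that
# writes exactly the keys read back in the resume branch above (the filename
# pattern and the use of `os.path.join` are assumptions):
def save_checkpoint(epoch, model_G, model_D, optim_G, optim_D,
                    lr_scheduler_G, lr_scheduler_D):
    torch.save(
        {
            'epoch': epoch,
            'model_G_state_dict': model_G.state_dict(),
            'model_D_state_dict': model_D.state_dict(),
            'optim_G_state_dict': optim_G.state_dict(),
            'optim_D_state_dict': optim_D.state_dict(),
            'lr_scheduler_G_state_dict': lr_scheduler_G.state_dict(),
            'lr_scheduler_D_state_dict': lr_scheduler_D.state_dict(),
        },
        os.path.join(settings.CHECKPOINT_DIR, f'checkpoint_{epoch}.pth'))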
Example #2
class Hidden:
    def __init__(self, configuration: HiDDenConfiguration,
                 device: torch.device, noiser: Noiser, tb_logger):
        """
        :param configuration: Configuration for the net, such as the size of the input image, number of channels in the intermediate layers, etc.
        :param device: torch.device object, CPU or GPU
        :param noiser: Object representing stacked noise layers.
        :param tb_logger: Optional TensorboardX logger object, if specified -- enables Tensorboard logging
        """
        super(Hidden, self).__init__()

        self.encoder_decoder = EncoderDecoder(configuration, noiser).to(device)
        self.discriminator = Discriminator(configuration).to(device)
        self.optimizer_enc_dec = torch.optim.Adam(
            self.encoder_decoder.parameters())
        self.optimizer_discrim = torch.optim.Adam(
            self.discriminator.parameters())

        if configuration.use_vgg:
            self.vgg_loss = VGGLoss(3, 1, False)
            self.vgg_loss.to(device)
        else:
            self.vgg_loss = None

        self.config = configuration
        self.device = device

        self.bce_with_logits_loss = nn.BCEWithLogitsLoss().to(device)
        self.mse_loss = nn.MSELoss().to(device)

        # Labels used for training the discriminator/adversarial loss
        # (floats, since torch.full and BCEWithLogitsLoss expect float targets)
        self.cover_label = 1.
        self.encoded_label = 0.

        self.tb_logger = tb_logger
        if tb_logger is not None:
            from tensorboard_logger import TensorBoardLogger
            encoder_final = self.encoder_decoder.encoder._modules[
                'final_layer']
            encoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/encoder_out'))
            decoder_final = self.encoder_decoder.decoder._modules['linear']
            decoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/decoder_out'))
            discrim_final = self.discriminator._modules['linear']
            discrim_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/discrim_out'))

    def train_on_batch(self, batch: list):
        """
        Trains the network on a single batch consisting of images and messages
        :param batch: batch of training data, in the form [images, messages]
        :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch
        """
        images, messages = batch

        batch_size = images.shape[0]
        self.encoder_decoder.train()
        self.discriminator.train()
        with torch.enable_grad():
            # ---------------- Train the discriminator -----------------------------
            self.optimizer_discrim.zero_grad()
            # train on cover
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              device=self.device)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encoded_label,
                                                device=self.device)
            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                device=self.device)

            d_on_cover = self.discriminator(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)
            d_loss_on_cover.backward()

            # train on fake
            encoded_images, noised_images, decoded_messages = self.encoder_decoder(
                images, messages)
            d_on_encoded = self.discriminator(encoded_images.detach())
            d_loss_on_encoded = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            d_loss_on_encoded.backward()
            self.optimizer_discrim.step()

            # --------------Train the generator (encoder-decoder) ---------------------
            self.optimizer_enc_dec.zero_grad()
            # target label for encoded images should be 'cover', because we want to fool the discriminator
            d_on_encoded_for_enc = self.discriminator(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)

            if self.vgg_loss is None:
                g_loss_enc = self.mse_loss(encoded_images, images)
            else:
                vgg_on_cov = self.vgg_loss(images)
                vgg_on_enc = self.vgg_loss(encoded_images)
                g_loss_enc = self.mse_loss(vgg_on_cov, vgg_on_enc)

            g_loss_dec = self.mse_loss(decoded_messages, messages)
            g_loss = self.config.adversarial_loss * g_loss_adv + self.config.encoder_loss * g_loss_enc \
                     + self.config.decoder_loss * g_loss_dec

            g_loss.backward()
            self.optimizer_enc_dec.step()

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_avg_err = np.sum(
            np.abs(decoded_rounded - messages.detach().cpu().numpy())) / (
                batch_size * messages.shape[1])

        losses = {
            'loss           ': g_loss.item(),
            'encoder_mse    ': g_loss_enc.item(),
            'dec_mse        ': g_loss_dec.item(),
            'bitwise-error  ': bitwise_avg_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encoded.item()
        }
        return losses, (encoded_images, noised_images, decoded_messages)

    def validate_on_batch(self, batch: list):
        """
        Runs validation on a single batch of data consisting of images and messages
        :param batch: batch of validation data, in form [images, messages]
        :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch
        """
        # if TensorboardX logging is enabled, save some of the tensors.
        if self.tb_logger is not None:
            encoder_final = self.encoder_decoder.encoder._modules[
                'final_layer']
            self.tb_logger.add_tensor('weights/encoder_out',
                                      encoder_final.weight)
            decoder_final = self.encoder_decoder.decoder._modules['linear']
            self.tb_logger.add_tensor('weights/decoder_out',
                                      decoder_final.weight)
            discrim_final = self.discriminator._modules['linear']
            self.tb_logger.add_tensor('weights/discrim_out',
                                      discrim_final.weight)

        images, messages = batch

        batch_size = images.shape[0]

        self.encoder_decoder.eval()
        self.discriminator.eval()
        with torch.no_grad():
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              device=self.device)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encoded_label,
                                                device=self.device)
            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                device=self.device)

            d_on_cover = self.discriminator(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)

            encoded_images, noised_images, decoded_messages = self.encoder_decoder(
                images, messages)

            d_on_encoded = self.discriminator(encoded_images)
            d_loss_on_encoded = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            d_on_encoded_for_enc = self.discriminator(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)

            if self.vgg_loss is None:
                g_loss_enc = self.mse_loss(encoded_images, images)
            else:
                vgg_on_cov = self.vgg_loss(images)
                vgg_on_enc = self.vgg_loss(encoded_images)
                g_loss_enc = self.mse_loss(vgg_on_cov, vgg_on_enc)

            g_loss_dec = self.mse_loss(decoded_messages, messages)
            g_loss = self.config.adversarial_loss * g_loss_adv + self.config.encoder_loss * g_loss_enc \
                     + self.config.decoder_loss * g_loss_dec

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_avg_err = np.sum(
            np.abs(decoded_rounded - messages.detach().cpu().numpy())) / (
                batch_size * messages.shape[1])

        losses = {
            'loss           ': g_loss.item(),
            'encoder_mse    ': g_loss_enc.item(),
            'dec_mse        ': g_loss_dec.item(),
            'bitwise-error  ': bitwise_avg_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encoded.item()
        }
        return losses, (encoded_images, noised_images, decoded_messages)

    def to_string(self):
        return '{}\n{}'.format(str(self.encoder_decoder),
                               str(self.discriminator))
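
# Minimal usage sketch; `configuration`, `noiser`, and `train_loader` are
# hypothetical names, not defined in the snippet above:
# hidden = Hidden(configuration, torch.device('cuda'), noiser, tb_logger=None)
# for images, messages in train_loader:
#     losses, _ = hidden.train_on_batch([images.cuda(), messages.cuda()])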
Example #3
dataset = Dataset('c:/DATASETS/celebA/data.txt',
                  'c:/DATASETS/celebA/img_align_celeba', transform)

# for a brand-new run, something like logs_idx = len(glob.glob('logs/*'))
# would start a fresh log directory
logs_idx = 0
depth_start = 0
epoch_start = 0
global_idx = 0

# resume from the most recent checkpoint in this log directory, if any
saves = glob.glob(f'logs/{logs_idx}/*.pt')
saves.sort(key=os.path.getmtime)
if saves:
    checkpoint = torch.load(saves[-1])
    generator.load_state_dict(checkpoint['generator_state_dict'])
    generator.train()
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    discriminator.train()
    g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
    d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
    depth_start = checkpoint['depth']
    epoch_start = checkpoint['epoch']
    global_idx = checkpoint['global_idx']

writer = tensorboard.SummaryWriter(log_dir=f'logs/{logs_idx}')
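# (when resuming, passing purge_step=global_idx to SummaryWriter would drop
# events logged after the checkpoint, as done in Example #1)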

for depth, (batch_size,
            epoch_size) in enumerate(tqdm(zip(batch_sizes[depth_start:],
                                              epoch_sizes[depth_start:]),
                                          initial=depth_start,
                                          total=len(epoch_sizes)),
                                     start=depth_start):
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,  # assumed completion; the source snippet is truncated here
                            shuffle=True)
Example #4
class HiDDen(object):
    def __init__(self, config: HiDDenConfiguration, device: torch.device):
        self.enc_dec = EncoderDecoder(config).to(device)
        self.discr = Discriminator(config).to(device)
        self.opt_enc_dec = torch.optim.Adam(self.enc_dec.parameters())
        self.opt_discr = torch.optim.Adam(self.discr.parameters())

        self.config = config
        self.device = device
        self.bce_with_logits_loss = nn.BCEWithLogitsLoss().to(device)
        self.mse_loss = nn.MSELoss().to(device)

        # labels for the discriminator / adversarial loss (floats, since
        # torch.full and BCEWithLogitsLoss expect float targets)
        self.cover_label = 1.
        self.encod_label = 0.

    def train_on_batch(self, batch: list):
        '''
        Trains the network on a single batch consisting of images and messages
        '''
        images, messages = batch
        batch_size = images.shape[0]
        self.enc_dec.train()
        self.discr.train()

        with torch.enable_grad():
            # ---------- Train the discriminator----------
            self.opt_discr.zero_grad()

            # train on cover
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              device=self.device)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encod_label,
                                                device=self.device)
            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                device=self.device)

            d_on_cover = self.discr(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)
            d_loss_on_cover.backward()

            # train on fake
            encoded_images, decoded_messages = self.enc_dec(images, messages)
            d_on_encoded = self.discr(encoded_images.detach())
            d_loss_on_encod = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)
            d_loss_on_encod.backward()
            self.opt_discr.step()

            #---------- Train the generator----------
            self.opt_enc_dec.zero_grad()

            d_on_encoded_for_enc = self.discr(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)
            g_loss_enc = self.mse_loss(encoded_images, images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)

            g_loss = self.config.adversarial_loss * g_loss_adv \
                    + self.config.encoder_loss * g_loss_enc \
                    + self.config.decoder_loss * g_loss_dec
            g_loss.backward()
            self.opt_enc_dec.step()

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_err = np.sum(np.abs(decoded_rounded - messages.detach().cpu().numpy())) \
                      / (batch_size * messages.shape[1])

        losses = {
            'loss': g_loss.item(),
            'encoder_mse': g_loss_enc.item(),
            'decoder_mse': g_loss_dec.item(),
            'bitwise-error': bitwise_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encod.item()
        }

        return losses, (encoded_images, decoded_messages)

    def validate_on_batch(self, batch: list):
        '''Runs validation on a single batch consisting of [images, messages]'''
        images, messages = batch
        batch_size = images.shape[0]

        self.enc_dec.eval()
        self.discr.eval()

        with torch.no_grad():
            d_target_label_cover = torch.full((batch_size, 1),
                                              self.cover_label,
                                              device=self.device)
            d_target_label_encoded = torch.full((batch_size, 1),
                                                self.encod_label,
                                                device=self.device)
            g_target_label_encoded = torch.full((batch_size, 1),
                                                self.cover_label,
                                                device=self.device)

            d_on_cover = self.discr(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)

            encoded_images, decoded_messages = self.enc_dec(images, messages)
            d_on_encoded = self.discr(encoded_images)
            d_loss_on_encod = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            d_on_encoded_for_enc = self.discr(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)
            g_loss_enc = self.mse_loss(encoded_images, images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)

            g_loss = self.config.adversarial_loss * g_loss_adv \
                    + self.config.encoder_loss * g_loss_enc \
                    + self.config.decoder_loss * g_loss_dec

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_err = np.sum(np.abs(decoded_rounded - messages.detach().cpu().numpy()))\
                     / (batch_size * messages.shape[1])

        losses = {
            'loss': g_loss.item(),
            'encoder_mse': g_loss_enc.item(),
            'decoder_mse': g_loss_dec.item(),
            'bitwise-error': bitwise_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encod.item()
        }

        return losses, (encoded_images, decoded_messages)

    def to_string(self):
        return f'{str(self.enc_dec)}\n{str(self.discr)}'
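
# Usage sketch: messages are random bit vectors. `config.message_length` and
# `loader` are assumptions, not part of the snippet above.
# model = HiDDen(config, torch.device('cuda'))
# for images in loader:
#     images = images.to(model.device)
#     messages = torch.randint(0, 2, (images.size(0), config.message_length),
#                              dtype=torch.float, device=model.device)
#     losses, _ = model.train_on_batch([images, messages])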
Example #5
class Trainer(nn.Module):
    def __init__(self, model_dir, g_optimizer, d_optimizer, lr, warmup,
                 max_iters):
        super().__init__()
        self.model_dir = model_dir
        if not os.path.exists(f'checkpoints/{model_dir}'):
            os.makedirs(f'checkpoints/{model_dir}')
        self.logs_dir = f'checkpoints/{model_dir}/logs'
        if not os.path.exists(self.logs_dir):
            os.makedirs(self.logs_dir)
        self.writer = SummaryWriter(self.logs_dir)

        self.arcface = ArcFaceNet(50, 0.6, 'ir_se').cuda()
        self.arcface.eval()
        self.arcface.load_state_dict(torch.load(
            'checkpoints/model_ir_se50.pth', map_location='cuda'),
                                     strict=False)

        self.mobiface = MobileFaceNet(512).cuda()
        self.mobiface.eval()
        self.mobiface.load_state_dict(torch.load(
            'checkpoints/mobilefacenet.pth', map_location='cuda'),
                                      strict=False)

        self.generator = Generator().cuda()
        self.discriminator = Discriminator().cuda()

        self.adversarial_weight = 1
        self.src_id_weight = 5
        self.tgt_id_weight = 1
        self.attributes_weight = 10
        self.reconstruction_weight = 10

        self.lr = lr
        self.warmup = warmup
        self.g_optimizer = g_optimizer(self.generator.parameters(),
                                       lr=lr,
                                       betas=(0, 0.999))
        self.d_optimizer = d_optimizer(self.discriminator.parameters(),
                                       lr=lr,
                                       betas=(0, 0.999))

        self.generator, self.g_optimizer = amp.initialize(self.generator,
                                                          self.g_optimizer,
                                                          opt_level="O1")
        self.discriminator, self.d_optimizer = amp.initialize(
            self.discriminator, self.d_optimizer, opt_level="O1")

        self._iter = nn.Parameter(torch.tensor(1), requires_grad=False)
        self.max_iters = max_iters

        if torch.cuda.is_available():
            self.cuda()

    @property
    def iter(self):
        return self._iter.item()

    @property
    def device(self):
        return next(self.parameters()).device

    def adapt(self, args):
        device = self.device
        return [arg.to(device) for arg in args]

    def train_loop(self, dataloaders, eval_every, generate_every, save_every):
        for batch in tqdm(dataloaders['train']):
            self._iter.add_(1)
            # generator step
            # if self.iter % 2 == 0:
            # self.adjust_lr(self.g_optimizer)
            g_losses = self.g_step(self.adapt(batch))
            g_stats = self.get_opt_stats(self.g_optimizer, type='generator')
            self.write_logs(losses=g_losses, stats=g_stats, type='generator')

            # #discriminator step
            # if self.iter % 2 == 1:
            # self.adjust_lr(self.d_optimizer)
            d_losses = self.d_step(self.adapt(batch))
            d_stats = self.get_opt_stats(self.d_optimizer,
                                         type='discriminator')
            self.write_logs(losses=d_losses,
                            stats=d_stats,
                            type='discriminator')

            if self.iter % eval_every == 0:
                discriminator_acc = self.evaluate_discriminator_accuracy(
                    dataloaders['val'])
                identification_acc = self.evaluate_identification_similarity(
                    dataloaders['val'])
                metrics = {**discriminator_acc, **identification_acc}
                self.write_logs(metrics=metrics)

            if self.iter % generate_every == 0:
                self.generate(*self.adapt(batch))

            if self.iter % save_every == 0:
                self.save_discriminator()
                self.save_generator()

    def g_step(self, batch):
        self.generator.train()
        self.g_optimizer.zero_grad()
        L_adv, L_src_id, L_tgt_id, L_attr, L_rec, L_generator = self.g_loss(
            *batch)
        with amp.scale_loss(L_generator, self.g_optimizer) as scaled_loss:
            scaled_loss.backward()
        self.g_optimizer.step()

        losses = {
            'adv': L_adv.item(),
            'src_id': L_src_id.item(),
            'tgt_id': L_tgt_id.item(),
            'attributes': L_attr.item(),
            'reconstruction': L_rec.item(),
            'total_loss': L_generator.item()
        }
        return losses

    def d_step(self, batch):
        self.discriminator.train()
        self.d_optimizer.zero_grad()
        L_fake, L_real, L_discriminator = self.d_loss(*batch)
        with amp.scale_loss(L_discriminator, self.d_optimizer) as scaled_loss:
            scaled_loss.backward()
        self.d_optimizer.step()

        losses = {
            'hinge_fake': L_fake.item(),
            'hinge_real': L_real.item(),
            'total_loss': L_discriminator.item()
        }
        return losses

    def g_loss(self, Xs, Xt, same_person):
        with torch.no_grad():
            src_embed = self.arcface(
                F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))
            tgt_embed = self.arcface(
                F.interpolate(Xt[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))

        Y_hat, Xt_attr = self.generator(Xt, src_embed, return_attributes=True)

        Di = self.discriminator(Y_hat)

        L_adv = 0
        for di in Di:
            L_adv += hinge_loss(di[0], True)

        fake_embed = self.arcface(
            F.interpolate(Y_hat[:, :, 19:237, 19:237], [112, 112],
                          mode='bilinear',
                          align_corners=True))
        L_src_id = (
            1 - torch.cosine_similarity(src_embed, fake_embed, dim=1)).mean()
        L_tgt_id = (
            1 - torch.cosine_similarity(tgt_embed, fake_embed, dim=1)).mean()

        batch_size = Xs.shape[0]
        Y_hat_attr = self.generator.get_attr(Y_hat)
        L_attr = 0
        for i in range(len(Xt_attr)):
            L_attr += torch.mean(torch.pow(Xt_attr[i] - Y_hat_attr[i],
                                           2).reshape(batch_size, -1),
                                 dim=1).mean()
        L_attr /= 2.0

        L_rec = torch.sum(
            0.5 * torch.mean(torch.pow(Y_hat - Xt, 2).reshape(batch_size, -1),
                             dim=1) * same_person) / (same_person.sum() + 1e-6)
        L_generator = (self.adversarial_weight *
                       L_adv) + (self.src_id_weight * L_src_id) + (
                           self.tgt_id_weight *
                           L_tgt_id) + (self.attributes_weight * L_attr) + (
                               self.reconstruction_weight * L_rec)
        return L_adv, L_src_id, L_tgt_id, L_attr, L_rec, L_generator

    def d_loss(self, Xs, Xt, same_person):
        with torch.no_grad():
            src_embed = self.arcface(
                F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))
        Y_hat = self.generator(Xt, src_embed, return_attributes=False)

        fake_D = self.discriminator(Y_hat.detach())
        L_fake = 0
        for di in fake_D:
            L_fake += hinge_loss(di[0], False)
        real_D = self.discriminator(Xs)
        L_real = 0
        for di in real_D:
            L_real += hinge_loss(di[0], True)

        L_discriminator = 0.5 * (L_real + L_fake)
        return L_fake, L_real, L_discriminator

    def evaluate_discriminator_accuracy(self, val_dataloader):
        real_acc = 0
        fake_acc = 0
        self.generator.eval()
        self.discriminator.eval()
        for batch in tqdm(val_dataloader):
            Xs, Xt, _ = self.adapt(batch)

            with torch.no_grad():
                embed = self.arcface(
                    F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                Y_hat = self.generator(Xt, embed, return_attributes=False)
                fake_D = self.discriminator(Y_hat)
                real_D = self.discriminator(Xs)

            fake_multiscale_acc = 0
            for di in fake_D:
                fake_multiscale_acc += torch.mean((di[0] < 0).float())
            fake_acc += fake_multiscale_acc / len(fake_D)

            real_multiscale_acc = 0
            for di in real_D:
                real_multiscale_acc += torch.mean((di[0] > 0).float())
            real_acc += real_multiscale_acc / len(real_D)

        self.generator.train()
        self.discriminator.train()

        metrics = {
            'fake_acc': 100 * (fake_acc / len(val_dataloader)).item(),
            'real_acc': 100 * (real_acc / len(val_dataloader)).item()
        }
        return metrics

    def evaluate_identification_similarity(self, val_dataloader):
        src_id_sim = 0
        tgt_id_sim = 0
        self.generator.eval()
        for batch in tqdm(val_dataloader):
            Xs, Xt, _ = self.adapt(batch)
            with torch.no_grad():
                src_embed = self.arcface(
                    F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                Y_hat = self.generator(Xt, src_embed, return_attributes=False)

                src_embed = self.mobiface(
                    F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                tgt_embed = self.mobiface(
                    F.interpolate(Xt[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))
                fake_embed = self.mobiface(
                    F.interpolate(Y_hat[:, :, 19:237, 19:237], [112, 112],
                                  mode='bilinear',
                                  align_corners=True))

            src_id_sim += (torch.cosine_similarity(src_embed,
                                                   fake_embed,
                                                   dim=1)).float().mean()
            tgt_id_sim += (torch.cosine_similarity(tgt_embed,
                                                   fake_embed,
                                                   dim=1)).float().mean()

        self.generator.train()

        metrics = {
            'src_similarity': 100 * (src_id_sim / len(val_dataloader)).item(),
            'tgt_similarity': 100 * (tgt_id_sim / len(val_dataloader)).item()
        }
        return metrics

    def generate(self, Xs, Xt, same_person):
        def get_grid_image(X):
            X = X[:8]
            X = torchvision.utils.make_grid(X.detach().cpu(), nrow=X.shape[0])
            X = (X * 0.5 + 0.5) * 255
            return X

        def make_image(Xs, Xt, Y_hat):
            Xs = get_grid_image(Xs)
            Xt = get_grid_image(Xt)
            Y_hat = get_grid_image(Y_hat)
            return torch.cat((Xs, Xt, Y_hat), dim=1).numpy()

        with torch.no_grad():
            embed = self.arcface(
                F.interpolate(Xs[:, :, 19:237, 19:237], [112, 112],
                              mode='bilinear',
                              align_corners=True))
            self.generator.eval()
            Y_hat = self.generator(Xt, embed, return_attributes=False)
            self.generator.train()

        image = make_image(Xs, Xt, Y_hat)
        if not os.path.exists(f'results/{self.model_dir}'):
            os.makedirs(f'results/{self.model_dir}')
        cv2.imwrite(f'results/{self.model_dir}/{self.iter}.jpg',
                    image.transpose([1, 2, 0]))

    def get_opt_stats(self, optimizer, type=''):
        stats = {f'{type}_lr': optimizer.param_groups[0]['lr']}
        return stats

    def adjust_lr(self, optimizer):
        if self.iter <= self.warmup:
            lr = self.lr * self.iter / self.warmup
        else:
            lr = self.lr * (1 + cos(pi * (self.iter - self.warmup) /
                                    (self.max_iters - self.warmup))) / 2

        for group in optimizer.param_groups:
            group['lr'] = lr
        return lr
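    # e.g. with lr=4e-4, warmup=1000, max_iters=100000 (hypothetical values),
    # adjust_lr returns 2e-4 at iter 500 (linear warmup) and exactly lr/2 at
    # iter 50500, the midpoint of the cosine decay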

    def write_logs(self, losses=None, metrics=None, stats=None, type='loss'):
        if losses:
            for name, value in losses.items():
                self.writer.add_scalar(f'{type}/{name}', value, self.iter)
        if metrics:
            for name, value in metrics.items():
                self.writer.add_scalar(f'metric/{name}', value, self.iter)
        if stats:
            for name, value in stats.items():
                self.writer.add_scalar(f'stats/{name}', value, self.iter)

    def save_generator(self, max_checkpoints=100):
        # look in the directory the checkpoints are actually written to, and
        # drop the oldest file once the limit is exceeded
        checkpoints = sorted(glob.glob(f'checkpoints/{self.model_dir}/generator_*.pt'),
                             key=os.path.getmtime)
        if len(checkpoints) > max_checkpoints:
            os.remove(checkpoints[0])
        with open(f'checkpoints/{self.model_dir}/generator_{self.iter}.pt',
                  'wb') as f:
            torch.save(self.generator.state_dict(), f)

    def save_discriminator(self, max_checkpoints=100):
        # same rotation scheme as save_generator
        checkpoints = sorted(glob.glob(f'checkpoints/{self.model_dir}/discriminator_*.pt'),
                             key=os.path.getmtime)
        if len(checkpoints) > max_checkpoints:
            os.remove(checkpoints[0])
        with open(f'checkpoints/{self.model_dir}/discriminator_{self.iter}.pt',
                  'wb') as f:
            torch.save(self.discriminator.state_dict(), f)

    def load_discriminator(self, path, load_last=True):
        if load_last:
            try:
                checkpoints = glob.glob(f'{path}/discriminator*.pt')
                path = max(checkpoints, key=os.path.getctime)
            except ValueError:
                print(f'Directory is empty: {path}')
                return

        try:
            self.discriminator.load_state_dict(torch.load(path))
            self.cuda()
        except FileNotFoundError:
            print(f'No such file: {path}')

    def load_generator(self, path, load_last=True):
        if load_last:
            try:
                checkpoints = glob.glob(f'{path}/generator*.pt')
                path = max(checkpoints, key=os.path.getctime)
            except ValueError:
                print(f'Directory is empty: {path}')
                return

        try:
            self.generator.load_state_dict(torch.load(path))
            # recover the iteration counter from the checkpoint filename
            iter_str = ''.join(filter(str.isdigit, os.path.basename(path)))
            self._iter = nn.Parameter(torch.tensor(int(iter_str)),
                                      requires_grad=False)
            self.cuda()
        except FileNotFoundError:
            print(f'No such file: {path}')
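
# `hinge_loss` is not defined in this snippet; a common hinge-GAN formulation
# that is consistent with the calls above (a sketch, assuming per-scale logit
# tensors):
def hinge_loss(logits, positive=True):
    # positive=True pulls logits above +1; positive=False pushes them below -1
    if positive:
        return torch.relu(1 - logits).mean()
    return torch.relu(1 + logits).mean()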
Example #6
def main():
    ## load std models
    # policy_log_std = torch.load('./model_pkl/policy_net_action_std_model_1.pkl')
    # transition_log_std = torch.load('./model_pkl/transition_net_state_std_model_1.pkl')

    # load expert data
    print(args.data_set_path)
    dataset = ExpertDataSet(args.data_set_path)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=args.expert_batch_size,
                                  shuffle=True,
                                  num_workers=0)
    # define actor/critic/discriminator net and optimizer
    policy = Policy(onehot_action_sections,
                    onehot_state_sections,
                    state_0=dataset.state)
    value = Value()
    discriminator = Discriminator()
    optimizer_policy = torch.optim.Adam(policy.parameters(), lr=args.policy_lr)
    optimizer_value = torch.optim.Adam(value.parameters(), lr=args.value_lr)
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=args.discrim_lr)
    discriminator_criterion = nn.BCELoss()
    if write_scalar:
        writer = SummaryWriter(log_dir='runs/' + model_name)

    # load net  models
    if load_model:
        discriminator.load_state_dict(
            torch.load('./model_pkl/Discriminator_model_' + model_name +
                       '.pkl'))
        policy.transition_net.load_state_dict(
            torch.load('./model_pkl/Transition_model_' + model_name + '.pkl'))
        policy.policy_net.load_state_dict(
            torch.load('./model_pkl/Policy_model_' + model_name + '.pkl'))
        value.load_state_dict(
            torch.load('./model_pkl/Value_model_' + model_name + '.pkl'))

        policy.policy_net_action_std = torch.load(
            './model_pkl/Policy_net_action_std_model_' + model_name + '.pkl')
        policy.transition_net_state_std = torch.load(
            './model_pkl/Transition_net_state_std_model_' + model_name +
            '.pkl')
    print('#############  start training  ##############')

    num = 0  # running count of PPO mini-batch updates
    for ep in tqdm(range(args.training_epochs)):
        # collect data from environment for ppo update
        policy.train()
        value.train()
        discriminator.train()
        start_time = time.time()
        memory, n_trajs = policy.collect_samples(
            batch_size=args.sample_batch_size)
        # print('sample_data_time:{}'.format(time.time()-start_time))
        batch = memory.sample()
        onehot_state = torch.cat(batch.onehot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        multihot_state = torch.cat(batch.multihot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        continuous_state = torch.cat(batch.continuous_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()

        onehot_action = torch.cat(batch.onehot_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        multihot_action = torch.cat(batch.multihot_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        continuous_action = torch.cat(batch.continuous_action, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        next_onehot_state = torch.cat(batch.next_onehot_state, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        next_multihot_state = torch.cat(batch.next_multihot_state,
                                        dim=1).reshape(
                                            n_trajs * args.sample_traj_length,
                                            -1).detach()
        next_continuous_state = torch.cat(
            batch.next_continuous_state,
            dim=1).reshape(n_trajs * args.sample_traj_length, -1).detach()

        old_log_prob = torch.cat(batch.old_log_prob, dim=1).reshape(
            n_trajs * args.sample_traj_length, -1).detach()
        mask = torch.cat(batch.mask,
                         dim=1).reshape(n_trajs * args.sample_traj_length,
                                        -1).detach()
        gen_state = torch.cat((onehot_state, multihot_state, continuous_state),
                              dim=-1)
        gen_action = torch.cat(
            (onehot_action, multihot_action, continuous_action), dim=-1)
        if ep % 1 == 0:
            # if (d_slow_flag and ep % 50 == 0) or (not d_slow_flag and ep % 1 == 0):
            d_loss = torch.empty(0, device=device)
            p_loss = torch.empty(0, device=device)
            v_loss = torch.empty(0, device=device)
            gen_r = torch.empty(0, device=device)
            expert_r = torch.empty(0, device=device)
            for expert_state_batch, expert_action_batch in data_loader:
                noise1 = torch.normal(0,
                                      args.noise_std,
                                      size=gen_state.shape,
                                      device=device)
                noise2 = torch.normal(0,
                                      args.noise_std,
                                      size=gen_action.shape,
                                      device=device)
                noise3 = torch.normal(0,
                                      args.noise_std,
                                      size=expert_state_batch.shape,
                                      device=device)
                noise4 = torch.normal(0,
                                      args.noise_std,
                                      size=expert_action_batch.shape,
                                      device=device)
                gen_r = discriminator(gen_state + noise1, gen_action + noise2)
                expert_r = discriminator(
                    expert_state_batch.to(device) + noise3,
                    expert_action_batch.to(device) + noise4)

                # gen_r = discriminator(gen_state, gen_action)
                # expert_r = discriminator(expert_state_batch.to(device), expert_action_batch.to(device))
                optimizer_discriminator.zero_grad()
                d_loss = discriminator_criterion(gen_r, torch.zeros(gen_r.shape, device=device)) + \
                            discriminator_criterion(expert_r,torch.ones(expert_r.shape, device=device))
                variance = 0.5 * torch.var(gen_r.to(device)) + 0.5 * torch.var(
                    expert_r.to(device))
                total_d_loss = d_loss - 10 * variance
                d_loss.backward()
                # total_d_loss.backward()
                optimizer_discriminator.step()
            if write_scalar:
                writer.add_scalar('d_loss', d_loss, ep)
                writer.add_scalar('total_d_loss', total_d_loss, ep)
                writer.add_scalar('variance', 10 * variance, ep)
        if ep % 1 == 0:
            # update PPO
            noise1 = torch.normal(0,
                                  args.noise_std,
                                  size=gen_state.shape,
                                  device=device)
            noise2 = torch.normal(0,
                                  args.noise_std,
                                  size=gen_action.shape,
                                  device=device)
            gen_r = discriminator(gen_state + noise1, gen_action + noise2)
            #if gen_r.mean().item() < 0.1:
            #    d_stop = True
            #if d_stop and gen_r.mean()
            optimize_iter_num = int(
                math.ceil(onehot_state.shape[0] / args.ppo_mini_batch_size))
            # gen_r = -(1 - gen_r + 1e-10).log()
            for ppo_ep in range(args.ppo_optim_epoch):
                for i in range(optimize_iter_num):
                    num += 1
                    index = slice(
                        i * args.ppo_mini_batch_size,
                        min((i + 1) * args.ppo_mini_batch_size,
                            onehot_state.shape[0]))
                    onehot_state_batch, multihot_state_batch, continuous_state_batch, onehot_action_batch, multihot_action_batch, continuous_action_batch, \
                    old_log_prob_batch, mask_batch, next_onehot_state_batch, next_multihot_state_batch, next_continuous_state_batch, gen_r_batch = \
                        onehot_state[index], multihot_state[index], continuous_state[index], onehot_action[index], multihot_action[index], continuous_action[index], \
                        old_log_prob[index], mask[index], next_onehot_state[index], next_multihot_state[index], next_continuous_state[index], gen_r[
                            index]
                    v_loss, p_loss = ppo_step(
                        policy, value, optimizer_policy, optimizer_value,
                        onehot_state_batch, multihot_state_batch,
                        continuous_state_batch, onehot_action_batch,
                        multihot_action_batch, continuous_action_batch,
                        next_onehot_state_batch, next_multihot_state_batch,
                        next_continuous_state_batch, gen_r_batch,
                        old_log_prob_batch, mask_batch, args.ppo_clip_epsilon)
                    if write_scalar:
                        writer.add_scalar('p_loss', p_loss, ep)
                        writer.add_scalar('v_loss', v_loss, ep)
        policy.eval()
        value.eval()
        discriminator.eval()
        noise1 = torch.normal(0,
                              args.noise_std,
                              size=gen_state.shape,
                              device=device)
        noise2 = torch.normal(0,
                              args.noise_std,
                              size=gen_action.shape,
                              device=device)
        gen_r = discriminator(gen_state + noise1, gen_action + noise2)
        expert_r = discriminator(
            expert_state_batch.to(device) + noise3,
            expert_action_batch.to(device) + noise4)
        gen_r_noise = gen_r.mean().item()
        expert_r_noise = expert_r.mean().item()
        gen_r = discriminator(gen_state, gen_action)
        expert_r = discriminator(expert_state_batch.to(device),
                                 expert_action_batch.to(device))
        if write_scalar:
            writer.add_scalar('gen_r', gen_r.mean(), ep)
            writer.add_scalar('expert_r', expert_r.mean(), ep)
            writer.add_scalar('gen_r_noise', gen_r_noise, ep)
            writer.add_scalar('expert_r_noise', expert_r_noise, ep)
        print('#' * 5 + 'training episode:{}'.format(ep) + '#' * 5)
        print('gen_r_noise', gen_r_noise)
        print('expert_r_noise', expert_r_noise)
        print('gen_r:', gen_r.mean().item())
        print('expert_r:', expert_r.mean().item())
        print('d_loss', d_loss.item())
        # save models
        if model_name is not None:
            torch.save(
                discriminator.state_dict(),
                './model_pkl/Discriminator_model_' + model_name + '.pkl')
            torch.save(policy.transition_net.state_dict(),
                       './model_pkl/Transition_model_' + model_name + '.pkl')
            torch.save(policy.policy_net.state_dict(),
                       './model_pkl/Policy_model_' + model_name + '.pkl')
            torch.save(
                policy.policy_net_action_std,
                './model_pkl/Policy_net_action_std_model_' + model_name +
                '.pkl')
            torch.save(
                policy.transition_net_state_std,
                './model_pkl/Transition_net_state_std_model_' + model_name +
                '.pkl')
            torch.save(value.state_dict(),
                       './model_pkl/Value_model_' + model_name + '.pkl')
        memory.clear_memory()
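
# The commented-out line "gen_r = -(1 - gen_r + 1e-10).log()" above is the
# usual GAIL surrogate reward; as a standalone helper it would look like
# this sketch:
def gail_reward(d_out, eps=1e-10):
    # grows as the discriminator scores a sample as expert-like (d_out -> 1)
    return -(1 - d_out + eps).log()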
Example #7
    losses = {
        'total': [],
        'kl': [],
        'bce': [],
        'dis': [],
        'gen': [],
        'classifier': [],
        'test': []
    }
    data_length = len(dataloader['train']) * opts.batch_size

    full_time = time()

    for e in range(opts.epochs):
        cvae.train()
        dis.train()

        e_loss = 0
        e_rec_loss = 0
        e_kl_loss = 0
        e_class_loss = 0
        e_dis_loss = 0
        e_gen_loss = 0
        e_classifier_loss = 0
        e_classifier_en_loss = 0

        epoch_time = time()

        for i, data in enumerate(dataloader['train'], 0):
            optimizer_cvae.zero_grad()
            optimizer_classifier.zero_grad()
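            # (the snippet is truncated here; the rest of the loop would
            # accumulate the batch losses into e_loss, e_kl_loss, etc., and
            # the per-epoch averages over data_length would be appended to
            # `losses` at the end of each epoch)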
Example #8
def main():

    # parse input size
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    # cudnn.enabled = True
    # gpu = args.gpu

    # create segmentation network
    model = DeepLab(num_classes=args.num_classes)

    # load pretrained parameters
    # if args.restore_from[:4] == 'http' :
    #     saved_state_dict = model_zoo.load_url(args.restore_from)
    # else:
    #     saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    # new_params = model.state_dict().copy()
    # for name, param in new_params.items():
    #     if name in saved_state_dict and param.size() == saved_state_dict[name].size():
    #         new_params[name].copy_(saved_state_dict[name])
    # model.load_state_dict(new_params)

    model.train()
    model.cpu()
    # model.cuda(args.gpu)
    # cudnn.benchmark = True

    # create discriminator network
    model_D = Discriminator(num_classes=args.num_classes)
    # if args.restore_from_D is not None:
    #     model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cpu()
    # model_D.cuda(args.gpu)

    # MILESTONE 1
    print("Printing MODELS ...")
    print(model)
    print(model_D)

    # Create directory to save snapshots of the model
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Load train data and ground truth labels
    # train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
    #                 scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)
    # train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
    #                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    # trainloader = data.DataLoader(train_dataset,
    #                 batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=False)
    # trainloader_gt = data.DataLoader(train_gt_dataset,
    #                 batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=False)

    train_dataset = MyCustomDataset()
    train_gt_dataset = MyCustomDataset()

    trainloader = data.DataLoader(train_dataset, batch_size=5, shuffle=True)
    trainloader_gt = data.DataLoader(train_gt_dataset,
                                     batch_size=5,
                                     shuffle=True)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # MILESTONE 2
    print("Printing Loaders")
    print(trainloader_iter)
    print(trainloader_gt_iter)

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # MILESTONE 3
    print("Printing OPTIMIZERS ...")
    print(optimizer)
    print(optimizer_D)

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    interp = nn.Upsample(size=input_size,  # (h, w), the order nn.Upsample expects
                         mode='bilinear',
                         align_corners=True)

    # labels for adversarial training
    pred_label = 0
    gt_label = 1
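    # (these scalar labels are expanded into full-size discriminator targets
    # by make_D_label; a sketch of that helper appears after this example's
    # training loop)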

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            # if (args.lambda_semi > 0 or args.lambda_semi_adv > 0 ) and i_iter >= args.semi_start_adv :
            #     try:
            #         _, batch = next(trainloader_remain_iter)
            #     except:
            #         trainloader_remain_iter = enumerate(trainloader_remain)
            #         _, batch = next(trainloader_remain_iter)

            #     # only access to img
            #     images, _, _, _ = batch
            #     images = Variable(images).cuda(args.gpu)

            #     pred = interp(model(images))
            #     pred_remain = pred.detach()

            #     D_out = interp(model_D(F.softmax(pred)))
            #     D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

            #     ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(np.bool)

            #     loss_semi_adv = args.lambda_semi_adv * bce_loss(D_out, make_D_label(gt_label, ignore_mask_remain))
            #     loss_semi_adv = loss_semi_adv/args.iter_size

            #     #loss_semi_adv.backward()
            #     loss_semi_adv_value += loss_semi_adv.data.cpu().numpy()/args.lambda_semi_adv

            #     if args.lambda_semi <= 0 or i_iter < args.semi_start:
            #         loss_semi_adv.backward()
            #         loss_semi_value = 0
            #     else:
            #         # produce ignore mask
            #         semi_ignore_mask = (D_out_sigmoid < args.mask_T)

            #         semi_gt = pred.data.cpu().numpy().argmax(axis=1)
            #         semi_gt[semi_ignore_mask] = 255

            #         semi_ratio = 1.0 - float(semi_ignore_mask.sum())/semi_ignore_mask.size
            #         print('semi ratio: {:.4f}'.format(semi_ratio))

            #         if semi_ratio == 0.0:
            #             loss_semi_value += 0
            #         else:
            #             semi_gt = torch.FloatTensor(semi_gt)

            #             loss_semi = args.lambda_semi * loss_calc(pred, semi_gt, args.gpu)
            #             loss_semi = loss_semi/args.iter_size
            #             loss_semi_value += loss_semi.data.cpu().numpy()/args.lambda_semi
            #             loss_semi += loss_semi_adv
            #             loss_semi.backward()

            # else:
            #     loss_semi = None
            #     loss_semi_adv = None

            # train with source

            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cpu()
            # images = Variable(images).cuda(args.gpu)
            ignore_mask = (labels.numpy() == 255)

            # segmentation prediction
            pred = interp(model(images))
            # (spatial multi-class) cross entropy loss
            loss_seg = loss_calc(pred, labels)
            # loss_seg = loss_calc(pred, labels, args.gpu)

            # discriminator prediction
            D_out = interp(model_D(F.softmax(pred, dim=1)))
            # adversarial loss
            loss_adv_pred = bce_loss(D_out,
                                     make_D_label(gt_label, ignore_mask))

            # multi-task loss
            # lambda_adv - weight for minimizing loss
            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # loss normalization
            loss = loss / args.iter_size

            # back propagation
            loss.backward()

            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.data.cpu().numpy(
            ) / args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            # if args.D_remain:
            #     pred = torch.cat((pred, pred_remain), 0)
            #     ignore_mask = np.concatenate((ignore_mask,ignore_mask_remain), axis = 0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train with gt
            # get gt labels
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cpu()
            # D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')  # `start` is assumed to be set earlier via timeit.default_timer()
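
# `make_D_label` and `one_hot` are not defined in this snippet (nor is
# `MyCustomDataset`, which per the unpacking above must yield
# (image, label, _, _) tuples). Minimal sketches of the two helpers,
# consistent with how they are called; the 255 value matches the
# `labels.numpy() == 255` ignore masks above and is assumed to be skipped
# inside BCEWithLogitsLoss2d:
def make_D_label(label, ignore_mask):
    # full-size target map filled with `label`; ignored pixels get 255
    ignore_mask = np.expand_dims(ignore_mask, axis=1)
    D_label = np.ones(ignore_mask.shape, dtype=np.float32) * label
    D_label[ignore_mask] = 255
    return Variable(torch.from_numpy(D_label)).cpu()


def one_hot(label):
    # (N, H, W) integer class map -> (N, C, H, W) one-hot float tensor
    label = label.numpy()
    encoded = np.zeros((label.shape[0], args.num_classes,
                        label.shape[1], label.shape[2]), dtype=np.float32)
    for c in range(args.num_classes):
        encoded[:, c, ...] = (label == c)
    return torch.from_numpy(encoded)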