Example #1
File: viar.py Project: GedamuA/VIAR
def build_models(args, device='cuda'):
    models = {}

    models['encodercnn'] = CNN(input_shape=RGB_INPUT_SHAPE,
                               model_name=args.encoder_cnn_model).to(device)

    models['encoder'] = Encoder(input_shape=models['encodercnn'].out_size,
                                encoder_block='convbilstm',
                                hidden_size=args.encoder_hid_size).to(device)

    models['crossviewdecodercnn'] = CNN(input_shape=DEPTH_INPUT_SHAPE,
                                        model_name=args.encoder_cnn_model,
                                        input_channel=1).to(device)

    crossviewdecoder_in_size = list(models['crossviewdecodercnn'].out_size)
    crossviewdecoder_in_size[0] = crossviewdecoder_in_size[0] * 3
    crossviewdecoder_in_size = torch.Size(crossviewdecoder_in_size)
    models['crossviewdecoder'] = CrossViewDecoder(
        input_shape=crossviewdecoder_in_size).to(device)

    models['reconstructiondecoder'] = ReconstructionDecoder(
        input_shape=models['encoder'].out_size[1:]).to(device)

    models['viewclassifier'] = ViewClassifier(
        input_size=reduce(operator.mul, models['encoder'].out_size[1:]),
        num_classes=5,
        reverse=(not args.disable_grl)).to(device)

    return models
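
A minimal usage sketch (an assumption, not from the project): the returned dict makes it easy to collect every sub-model's parameters into a single optimizer. Here args is a hypothetical argparse namespace and the learning rate is illustrative.

import itertools
import torch

models = build_models(args, device='cuda')
all_params = itertools.chain(*(m.parameters() for m in models.values()))
optimizer = torch.optim.Adam(all_params, lr=1e-4)  # hypothetical lr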
Example #2
    def __init__(self, args):

        self.z_dim = args.z_dim
        self.decay_rate = args.decay_rate
        self.learning_rate = args.learning_rate
        self.model_name = args.model_name
        self.batch_size = args.batch_size

        #initialize networks
        self.Generator = Generator(self.z_dim).cuda()
        self.Encoder = Encoder(self.z_dim).cuda()
        self.Discriminator = Discriminator().cuda()

        #set optimizers for all networks
        self.optimizer_G_E = torch.optim.Adam(
            list(self.Generator.parameters()) +
            list(self.Encoder.parameters()),
            lr=self.learning_rate,
            betas=(0.5, 0.999))

        self.optimizer_D = torch.optim.Adam(self.Discriminator.parameters(),
                                            lr=self.learning_rate,
                                            betas=(0.5, 0.999))

        #initialize network weights
        self.Generator.apply(weights_init)
        self.Encoder.apply(weights_init)
        self.Discriminator.apply(weights_init)
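
weights_init is applied above but not defined in this snippet; a common DCGAN-style initializer consistent with this usage (a sketch, not necessarily the project's version) is:

import torch.nn as nn

def weights_init(m):
    # normal init for conv weights, unit-mean normal for batch-norm scales
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)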
Example #3
    def __init__(self, alpha=1., beta=1., gamma=0.1):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        with self.init_scope():
            self.encoder = Encoder()
            self.local_disc = LocalDiscriminator()
            self.global_disc = GlobalDiscriminator()
            self.prior_disc = PriorDiscriminator()
Example #4
    def __init__(self, embedding, output_size=2):
        super(Decoder, self).__init__()
        self.output_size = output_size

        self.positions = nn.Linear(config.hidden_size * 2, output_size)
        self.encoder = Encoder(embedding, config.batch_size,
                               config.hidden_size, config.num_encoder_layers,
                               config.encoder_bidirectional)
        self.softmax = nn.Softmax(dim=1)
        self.qn_linear = nn.Linear(config.hidden_size, config.hidden_size)
        self.criterion = nn.MSELoss()
Example #5
    def __init__(self,
                 z_dim=50,
                 hidden_dim=400,
                 enc_kernel1=5,
                 enc_kernel2=5,
                 use_cuda=False):
        super(VAE, self).__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim, enc_kernel1, enc_kernel2)
        self.decoder = Decoder(z_dim, hidden_dim)

        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim
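
The sampling step is not shown in this snippet; a standard reparameterization helper matching the signature used in the later examples (a sketch, assuming the encoder returns a mean and log-variance) is:

import torch

def reparameterize(training, mu, logvar):
    if not training:
        return mu  # use the posterior mean deterministically at eval time
    std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)    # eps ~ N(0, I)
    return mu + eps * std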
Example #6
    def __init__(self,
                 in_size: int,
                 ts_size: int = 100,
                 latent_dim: int = 20,
                 lr: float = 0.0005,
                 weight_decay: float = 1e-6,
                 iterations_critic: int = 5,
                 gamma: float = 10,
                 weighted: bool = True,
                 use_gru=False):
        super(TadGAN, self).__init__()
        self.in_size = in_size
        self.latent_dim = latent_dim
        self.lr = lr
        self.weight_decay = weight_decay
        self.iterations_critic = iterations_critic
        self.gamma = gamma
        self.weighted = weighted

        self.hparams = {
            'lr': self.lr,
            'weight_decay': self.weight_decay,
            'iterations_critic': self.iterations_critic,
            'gamma': self.gamma
        }

        self.encoder = Encoder(in_size,
                               ts_size=ts_size,
                               out_size=self.latent_dim,
                               batch_first=True,
                               use_gru=use_gru)
        self.generator = Generator(use_gru=use_gru)
        self.critic_x = CriticX(in_size=in_size)
        self.critic_z = CriticZ()

        self.encoder.apply(init_weights)
        self.generator.apply(init_weights)
        self.critic_x.apply(init_weights)
        self.critic_z.apply(init_weights)

        if self.logger is not None:
            self.logger.log_hyperparams(self.hparams)

        self.y_hat = []
        self.index = []
        self.critic = []
Example #7
def main(FLAGS):
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    device = 'cuda:0'

    decoder.to(device)
    encoder.to(device)

    tsne = TSNE(2)

    mnist = DataLoader(
        datasets.MNIST(root='mnist',
                       download=True,
                       train=False,
                       transform=transform_config))
    s_dict = {}
    with torch.no_grad():
        for i, (image, label) in enumerate(mnist):
            label = int(label)
            print(i, label)
            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(
                image.to(device))
            s_dict.setdefault(label, []).append(class_latent_space_1)

    s_all = []
    for label in range(10):
        s_all.extend(s_dict[label])

    s_all = torch.cat(s_all)
    s_all = s_all.view(s_all.shape[0], -1).cpu()

    s_2d = tsne.fit_transform(s_all)

    np.savez('s_2d.npz', s_2d=s_2d)
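
A short follow-up sketch (hypothetical, reading only the s_2d.npz file saved above) to plot the 2-D embedding:

import numpy as np
import matplotlib.pyplot as plt

s_2d = np.load('s_2d.npz')['s_2d']
plt.scatter(s_2d[:, 0], s_2d[:, 1], s=2)
plt.title('t-SNE of MNIST class latent space')
plt.savefig('s_2d.png')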
Example #8
 def __init__(self,
              latents_sizes,
              latents_names,
              img_dim=4096,
              label_dim=114,
              latent_dim=200,
              use_CUDA=False):
     super(VAE, self).__init__()
     #creating networks
     self.encoder = Encoder(img_dim, label_dim, latent_dim)
     self.decoder = Decoder(img_dim, label_dim, latent_dim)
     self.img_dim = img_dim
     self.label_dim = label_dim
     self.latent_dim = latent_dim
     self.latents_sizes = latents_sizes
     self.latents_names = latents_names
     if use_CUDA:
         self.cuda()
     self.use_CUDA = use_CUDA
Example #9
    def __init__(self, input_type='image', representation_type='image',
                 output_type=['image'], s_type='classes', input_dim=104,
                 representation_dim=8, output_dim=[1], s_dim=1,
                 problem='privacy', beta=1.0, gamma=1.0, prior_type='Gaussian'):
        super(VPAF, self).__init__()

        self.problem = problem
        self.param = gamma if self.problem == 'privacy' else beta
        self.input_type = input_type
        self.representation_type = representation_type
        self.output_type = output_type
        self.output_dim = output_dim
        self.s_type = s_type
        self.prior_type = prior_type

        self.encoder = Encoder(input_type, representation_type, input_dim,
                               representation_dim)
        self.decoder = Decoder(representation_type,
                               output_type,
                               representation_dim,
                               output_dim,
                               s_dim=s_dim)
Example #10
    def __init__(self, env, tau=0.1, gamma=0.9, epsilon=1.0):
        self.env = env
        self.tau = tau
        self.gamma = gamma
        self.embedding_size = 30
        self.hidden_size = 30
        self.obs_shape = self.env.get_obs().shape
        self.action_shape = 40 // 5
        if args.encoding == "onehot":
            self.encoder = OneHot(
                args.bins,
                self.env.all_questions + self.env.held_out_questions,
                self.hidden_size).to(DEVICE)
        else:
            self.encoder = Encoder(self.embedding_size,
                                   self.hidden_size).to(DEVICE)

        self.model = DQN(self.obs_shape, self.action_shape,
                         self.encoder).to(DEVICE)
        self.target_model = DQN(self.obs_shape, self.action_shape,
                                self.encoder).to(DEVICE)
        self.optimizer = torch.optim.Adam(self.model.parameters())
        self.epsilon = epsilon
        if os.path.exists(MODEL_FILE):
            checkpoint = torch.load(MODEL_FILE)
            self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.target_model.load_state_dict(
                checkpoint['target_model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epsilon = checkpoint['epsilon']

        # hard copy model parameters to target model parameters
        for target_param, param in zip(self.target_model.parameters(),
                                       self.model.parameters()):
            target_param.data.copy_(param)
Example #11
    def __init__(self, env, tau=0.05, gamma=0.9, epsilon=1.0):
        super().__init__()
        self.env = env
        self.tau = tau
        self.gamma = gamma
        self.embedding_size = 64
        self.hidden_size = 64
        self.obs_shape = self.env.get_obs().shape
        self.action_shape = 40 // 5
        self.encoder = Encoder(self.embedding_size,
                               self.hidden_size).to(DEVICE)

        self.model = DQN(self.obs_shape, self.action_shape,
                         self.encoder).to(DEVICE)
        self.target_model = DQN(self.obs_shape, self.action_shape,
                                self.encoder).to(DEVICE)
        self.target_model.eval()
        self.optimizer = torch.optim.Adam(self.model.parameters())
        self.epsilon = epsilon

        # hard copy model parameters to target model parameters
        for target_param, param in zip(self.target_model.parameters(),
                                       self.model.parameters()):
            target_param.data.copy_(param)
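
tau is stored but only a hard copy is performed at construction; the usual companion during training is a Polyak soft update (a sketch under that assumption, not shown in the original):

def soft_update(model, target_model, tau):
    # target <- tau * online + (1 - tau) * target
    for target_param, param in zip(target_model.parameters(),
                                   model.parameters()):
        target_param.data.copy_(tau * param.data +
                                (1.0 - tau) * target_param.data)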
Example #12
import sys, os
sys.path.append("..")
sys.path.extend([
    os.path.join(root, name) for root, dirs, _ in os.walk("../")
    for name in dirs
])

from _config import NNConfig
from networks import CNN, LSTM, Encoder, Decoder

nnconfig = NNConfig()
nnconfig.show()

cnn = CNN("cnn_layer1")
lstm = LSTM("lstm_layer1")
cnn.show()
lstm.show()

encoder = Encoder(cnn)
decoder = Decoder(lstm)
Example #13
    # sigma_p_inv: (n_dim, n_frames, n_frames), det_p: (d)
    # sigma_q: (batch_size, n_dim, n_frames, n_frames), mu_q: (batch_size, d, nlen)

    l1 = torch.einsum('kij,mkji->mk', sigma_p_inv,
                      sigma_q)  # tr(sigma_p_inv sigma_q)
    l2 = torch.einsum('mki,mki->mk', mu_p - mu_q,
                      torch.einsum('kij,mkj->mki', sigma_p_inv,
                                   mu_p - mu_q))  # <mu_q, sigma_p_inv, mu_q>
    loss = torch.sum(l1 + l2 + torch.log(det_p) - torch.log(det_q), dim=1)
    return loss


if __name__ == '__main__':

    # model definition
    encoder = Encoder()
    encoder.apply(weights_init)

    decoder = Decoder()
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if LOAD_SAVED:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', ENCODER_SAVE)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', DECODER_SAVE)))

    # loss definition
    mse_loss = nn.MSELoss()
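
A hypothetical shape check for the KL term above, using dummy tensors that match the commented shapes:

import torch

batch_size, n_dim, n_frames = 4, 8, 16
sigma_p_inv = torch.randn(n_dim, n_frames, n_frames)
sigma_q = torch.randn(batch_size, n_dim, n_frames, n_frames)
mu_p = torch.randn(n_dim, n_frames)              # broadcasts against mu_q
mu_q = torch.randn(batch_size, n_dim, n_frames)

l1 = torch.einsum('kij,mkji->mk', sigma_p_inv, sigma_q)
l2 = torch.einsum('mki,mki->mk', mu_p - mu_q,
                  torch.einsum('kij,mkj->mki', sigma_p_inv, mu_p - mu_q))
assert l1.shape == l2.shape == (batch_size, n_dim)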
Example #14
    def __init__(self, hparams):
        super().__init__()

        # output path
        self.output_path = "outputs/"

        # self.logger.log_hyperparams(hparams)  # log hyperparameters
        self.hparams = hparams

        self.files = ["../data.csv"]
        self.n = len(self.files)

        # load data
        self.data = [torch.from_numpy(np.genfromtxt(file, delimiter=',').transpose()[1:, 1:]).float() for file in self.files]

        self.datasets = [torch.utils.data.TensorDataset(data) for data in self.data]

        self.test_size = 60

        self.train_dataset, self.test_dataset = zip(*(
            torch.utils.data.random_split(dataset, (len(dataset) - self.test_size, self.test_size))
            for dataset in self.datasets))

        input_dim = 3000

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # distance matrix
        print("computing distance matrices")
        self.distance_matrix_train = self.distance_matrix(torch.stack(list(self.UnTensorDataset(self.train_dataset[0]))))
        if not self.hparams.correlation_distance_loss:
            self.distance_total_train = self.distance_matrix_train.sum(1)
            # self.distance_matrix_train_norm = self.distance_matrix_train / self.distance_total_train.unsqueeze(1)
            self.distance_matrix_train_norm = self.distance_matrix_train / self.distance_total_train.sum()
        print("done")

        # define VAEs
        self.E = [Encoder(input_dim, self.hparams.latent_dim, self.hparams.hypersphere).to(device) for _ in range(self.n)]  # hparams available for activation and dropout

        self.G = [Generator(self.hparams.latent_dim, input_dim).to(device) for _ in range(self.n)]

        # share weights
        if self.hparams.share_weights:
            for E in self.E[1:]:
                E.s1 = self.E[0].s1
                E.s2m = self.E[0].s2m
                E.s2v = self.E[0].s2v
            for G in self.G[1:]:
                G.s1 = self.G[0].s1
                G.s2 = self.G[0].s2

        # define discriminators
        if self.hparams.separate_D:
            self.D = [[Discriminator(input_dim).to(device) if i != j else None
                       for j in range(self.n)] for i in range(self.n)]
        else:
            self.D = [[Discriminator(input_dim).to(device) for _ in range(self.n)]] * self.n

        # named modules to make pytorch lightning happy
        self.E0 = self.E[0]
        self.G0 = self.G[0]

        # hyperspherical distribution
        self.p_z = HypersphericalUniform(self.hparams.latent_dim - 1, device=device) \
            if self.hparams.hypersphere else None

        # cache
        self.prev_g_loss = None
        self.current_z = self.forward(torch.stack(list(self.UnTensorDataset(self.train_dataset[0]))).to(device), first=True).z_a.detach()
Example #15
    summand = px_cond * p_s * p_g / q_s / q_g
    print(summand)

    likelihood = torch.sum(summand) / summand.size(0)
    



FLAGS = parser.parse_args()

if __name__ == '__main__':
    """
    model definitions
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)

    encoder.load_state_dict(
        torch.load(os.path.join('checkpoints', FLAGS.encoder_save), map_location=lambda storage, loc: storage))
    decoder.load_state_dict(
        torch.load(os.path.join('checkpoints', FLAGS.decoder_save), map_location=lambda storage, loc: storage))

    encoder.cuda()
    decoder.cuda()

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load data set and create data loader instance
Example #16
parser.add_argument("--print_interval", type=int, default=100, help="interval of loss printing")
parser.add_argument("--dataroot", default="", help="path to dataset")
parser.add_argument("--dataset", default="cifar10", help="folder | cifar10 | mnist")
parser.add_argument("--abnormal_class", default="airplane", help="Anomaly class idx for mnist and cifar datasets")
parser.add_argument("--out", default="ckpts", help="checkpoint directory")
parser.add_argument("--device", default="cuda", help="device: cuda | cpu")
parser.add_argument("--G_path", default="ckpts/G_epoch19.pt", help="path to trained state dict of generator")
parser.add_argument("--D_path", default="ckpts/D_epoch19.pt", help="path to trained state dict of discriminator")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

generator = Generator(dim=64, zdim=opt.latent_dim, nc=opt.channels)
discriminator = Discriminator(dim=64, zdim=opt.latent_dim, nc=opt.channels, out_feat=True)
encoder = Encoder(dim=64, zdim=opt.latent_dim, nc=opt.channels)

generator.load_state_dict(torch.load(opt.G_path))
discriminator.load_state_dict(torch.load(opt.D_path))
generator.to(opt.device)
encoder.to(opt.device)
discriminator.to(opt.device)

encoder.train()
discriminator.train()

dataloader = load_data(opt)

generator.eval()

Tensor = torch.cuda.FloatTensor if opt.device == 'cuda' else torch.FloatTensor
Example #17
def train(opt):
    #### device
    device = torch.device('cuda:{}'.format(opt.gpu_id)
                          if opt.gpu_id >= 0 else 'cpu')

    #### dataset
    data_loader = UnAlignedDataLoader()
    data_loader.initialize(opt)
    data_set = data_loader.load_data()
    print("The number of training images = %d." % len(data_set))

    #### initialize models
    ## declaration
    E_a2Zb = Encoder(input_nc=opt.input_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type,
                     use_dropout=not opt.no_dropout,
                     n_blocks=9)
    G_Zb2b = Decoder(output_nc=opt.output_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type)
    T_Zb2Za = LatentTranslator(n_channels=256,
                               norm_type=opt.norm_type,
                               use_dropout=not opt.no_dropout)
    D_b = Discriminator(input_nc=opt.input_nc,
                        ndf=opt.ndf,
                        n_layers=opt.n_layers,
                        norm_type=opt.norm_type)

    E_b2Za = Encoder(input_nc=opt.input_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type,
                     use_dropout=not opt.no_dropout,
                     n_blocks=9)
    G_Za2a = Decoder(output_nc=opt.output_nc,
                     ngf=opt.ngf,
                     norm_type=opt.norm_type)
    T_Za2Zb = LatentTranslator(n_channels=256,
                               norm_type=opt.norm_type,
                               use_dropout=not opt.no_dropout)
    D_a = Discriminator(input_nc=opt.input_nc,
                        ndf=opt.ndf,
                        n_layers=opt.n_layers,
                        norm_type=opt.norm_type)

    ## initialization
    E_a2Zb = init_net(E_a2Zb, init_type=opt.init_type).to(device)
    G_Zb2b = init_net(G_Zb2b, init_type=opt.init_type).to(device)
    T_Zb2Za = init_net(T_Zb2Za, init_type=opt.init_type).to(device)
    D_b = init_net(D_b, init_type=opt.init_type).to(device)

    E_b2Za = init_net(E_b2Za, init_type=opt.init_type).to(device)
    G_Za2a = init_net(G_Za2a, init_type=opt.init_type).to(device)
    T_Za2Zb = init_net(T_Za2Zb, init_type=opt.init_type).to(device)
    D_a = init_net(D_a, init_type=opt.init_type).to(device)
    print(
        "+------------------------------------------------------+\nFinish initializing networks."
    )

    #### optimizer and criterion
    ## criterion
    criterionGAN = GANLoss(opt.gan_mode).to(device)
    criterionZId = nn.L1Loss()
    criterionIdt = nn.L1Loss()
    criterionCTC = nn.L1Loss()
    criterionZCyc = nn.L1Loss()

    ## optimizer
    optimizer_G = torch.optim.Adam(itertools.chain(E_a2Zb.parameters(),
                                                   G_Zb2b.parameters(),
                                                   T_Zb2Za.parameters(),
                                                   E_b2Za.parameters(),
                                                   G_Za2a.parameters(),
                                                   T_Za2Zb.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, opt.beta2))
    optimizer_D = torch.optim.Adam(itertools.chain(D_a.parameters(),
                                                   D_b.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.beta1, opt.beta2))

    ## scheduler
    scheduler = [
        get_scheduler(optimizer_G, opt),
        get_scheduler(optimizer_D, opt)
    ]

    print(
        "+------------------------------------------------------+\nFinish initializing the optimizers and criterions."
    )

    #### global variables
    checkpoints_pth = os.path.join(opt.checkpoints, opt.name)
    if not os.path.exists(checkpoints_pth):
        os.makedirs(checkpoints_pth)
        os.mkdir(os.path.join(checkpoints_pth, 'images'))
    record_fh = open(os.path.join(checkpoints_pth, 'records.txt'),
                     'w',
                     encoding='utf-8')
    loss_names = [
        'GAN_A', 'Adv_A', 'Idt_A', 'CTC_A', 'ZId_A', 'ZCyc_A', 'GAN_B',
        'Adv_B', 'Idt_B', 'CTC_B', 'ZId_B', 'ZCyc_B'
    ]

    fake_A_pool = ImagePool(
        opt.pool_size
    )  # create image buffer to store previously generated images
    fake_B_pool = ImagePool(
        opt.pool_size
    )  # create image buffer to store previously generated images

    print(
        "+------------------------------------------------------+\nFinish preparing the other works."
    )
    print(
        "+------------------------------------------------------+\nNow training is beginning .."
    )
    #### training
    cur_iter = 0
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()  # timer for entire epoch

        for i, data in enumerate(data_set):
            ## setup inputs
            real_A = data['A'].to(device)
            real_B = data['B'].to(device)

            ## forward
            # image cycle / GAN
            latent_B = E_a2Zb(real_A)  #-> a -> Zb     : E_a2b(a)
            fake_B = G_Zb2b(latent_B)  #-> Zb -> b'    : G_b(E_a2b(a))
            latent_A = E_b2Za(real_B)  #-> b -> Za     : E_b2a(b)
            fake_A = G_Za2a(latent_A)  #-> Za -> a'    : G_a(E_b2a(b))

            # Idt
            '''
            rec_A = G_Za2a(E_b2Za(fake_B))          #-> b' -> Za' -> rec_a  : G_a(E_b2a(fake_b))
            rec_B = G_Zb2b(E_a2Zb(fake_A))          #-> a' -> Zb' -> rec_b  : G_b(E_a2b(fake_a))
            '''
            idt_latent_A = E_b2Za(real_A)  #-> a -> Za        : E_b2a(a)
            idt_A = G_Za2a(idt_latent_A)  #-> Za -> idt_a    : G_a(E_b2a(a))
            idt_latent_B = E_a2Zb(real_B)  #-> b -> Zb        : E_a2b(b)
            idt_B = G_Zb2b(idt_latent_B)  #-> Zb -> idt_b    : G_b(E_a2b(b))

            # ZIdt
            T_latent_A = T_Zb2Za(latent_B)  #-> Zb -> Za''  : T_b2a(E_a2b(a))
            T_rec_A = G_Za2a(
                T_latent_A)  #-> Za'' -> a'' : G_a(T_b2a(E_a2b(a)))
            T_latent_B = T_Za2Zb(latent_A)  #-> Za -> Zb''  : T_a2b(E_b2a(b))
            T_rec_B = G_Zb2b(
                T_latent_B)  #-> Zb'' -> b'' : G_b(T_a2b(E_b2a(b)))

            # CTC
            T_idt_latent_B = T_Za2Zb(idt_latent_A)  #-> a -> T_a2b(E_b2a(a))
            T_idt_latent_A = T_Zb2Za(idt_latent_B)  #-> b -> T_b2a(E_a2b(b))

            # ZCyc
            TT_latent_B = T_Za2Zb(T_latent_A)  #-> T_a2b(T_b2a(E_a2b(a)))
            TT_latent_A = T_Zb2Za(T_latent_B)  #-> T_b2a(T_a2b(E_b2a(b)))

            ### optimize parameters
            ## Generator updating
            set_requires_grad(
                [D_b, D_a],
                False)  #-> set Discriminator to require no gradient
            optimizer_G.zero_grad()
            # GAN loss
            loss_G_A = criterionGAN(D_b(fake_B), True)
            loss_G_B = criterionGAN(D_a(fake_A), True)
            loss_GAN = loss_G_A + loss_G_B
            # Idt loss
            loss_idt_A = criterionIdt(idt_A, real_A)
            loss_idt_B = criterionIdt(idt_B, real_B)
            loss_Idt = loss_idt_A + loss_idt_B
            # Latent cross-identity loss
            loss_Zid_A = criterionZId(T_rec_A, real_A)
            loss_Zid_B = criterionZId(T_rec_B, real_B)
            loss_Zid = loss_Zid_A + loss_Zid_B
            # Latent cross-translation consistency
            loss_CTC_A = criterionCTC(T_idt_latent_A, latent_A)
            loss_CTC_B = criterionCTC(T_idt_latent_B, latent_B)
            loss_CTC = loss_CTC_B + loss_CTC_A
            # Latent cycle consistency
            loss_ZCyc_A = criterionZCyc(TT_latent_A, latent_A)
            loss_ZCyc_B = criterionZCyc(TT_latent_B, latent_B)
            loss_ZCyc = loss_ZCyc_B + loss_ZCyc_A

            loss_G = opt.lambda_gan * loss_GAN + opt.lambda_idt * loss_Idt + opt.lambda_zid * loss_Zid + opt.lambda_ctc * loss_CTC + opt.lambda_zcyc * loss_ZCyc

            # backward and gradient updating
            loss_G.backward()
            optimizer_G.step()

            ## Discriminator updating
            set_requires_grad([D_b, D_a],
                              True)  # -> set Discriminator to require gradient
            optimizer_D.zero_grad()

            # backward D_b
            fake_B_ = fake_B_pool.query(fake_B)
            #-> real_B, fake_B
            pred_real_B = D_b(real_B)
            loss_D_real_B = criterionGAN(pred_real_B, True)

            pred_fake_B = D_b(fake_B_)
            loss_D_fake_B = criterionGAN(pred_fake_B, False)

            loss_D_B = (loss_D_real_B + loss_D_fake_B) * 0.5
            loss_D_B.backward()

            # backward D_a
            fake_A_ = fake_A_pool.query(fake_A)
            #-> real_A, fake_A
            pred_real_A = D_a(real_A)
            loss_D_real_A = criterionGAN(pred_real_A, True)

            pred_fake_A = D_a(fake_A_)
            loss_D_fake_A = criterionGAN(pred_fake_A, False)

            loss_D_A = (loss_D_real_A + loss_D_fake_A) * 0.5
            loss_D_A.backward()

            # update the gradients
            optimizer_D.step()

            ### validate here, both qualitively and quantitatively
            ## record the losses
            if cur_iter % opt.log_freq == 0:
                # loss_names = ['GAN_A', 'Adv_A', 'Idt_A', 'CTC_A', 'ZId_A', 'ZCyc_A', 'GAN_B', 'Adv_B', 'Idt_B', 'CTC_B', 'ZId_B', 'ZCyc_B']
                losses = [
                    loss_G_A.item(),
                    loss_D_A.item(),
                    loss_idt_A.item(),
                    loss_CTC_A.item(),
                    loss_Zid_A.item(),
                    loss_ZCyc_A.item(),
                    loss_G_B.item(),
                    loss_D_B.item(),
                    loss_idt_B.item(),
                    loss_CTC_B.item(),
                    loss_Zid_B.item(),
                    loss_ZCyc_B.item()
                ]
                # record
                line = ''
                for loss in losses:
                    line += '{} '.format(loss)
                record_fh.write(line[:-1] + '\n')
                # print out
                print('Epoch: %3d/%3d Iter: %9d--------------------------+' %
                      (epoch, opt.epoch, i))
                field_names = loss_names[:len(loss_names) // 2]
                table = PrettyTable(field_names=field_names)
                for l_n in field_names:
                    table.align[l_n] = 'c'  # PrettyTable only accepts 'l', 'c', 'r'
                table.add_row(losses[:len(field_names)])
                print(table.get_string(reversesort=True))

                field_names = loss_names[len(loss_names) // 2:]
                table = PrettyTable(field_names=field_names)
                for l_n in field_names:
                    table.align[l_n] = 'c'  # PrettyTable only accepts 'l', 'c', 'r'
                table.add_row(losses[-len(field_names):])
                print(table.get_string(reversesort=True))

            ## visualize
            if cur_iter % opt.vis_freq == 0:
                if opt.gpu_id >= 0:
                    real_A = real_A.cpu().data
                    real_B = real_B.cpu().data
                    fake_A = fake_A.cpu().data
                    fake_B = fake_B.cpu().data
                    idt_A = idt_A.cpu().data
                    idt_B = idt_B.cpu().data
                    T_rec_A = T_rec_A.cpu().data
                    T_rec_B = T_rec_B.cpu().data

                plt.subplot(241), plt.title('real_A'), plt.imshow(
                    tensor2image_RGB(real_A[0, ...]))
                plt.subplot(242), plt.title('fake_B'), plt.imshow(
                    tensor2image_RGB(fake_B[0, ...]))
                plt.subplot(243), plt.title('idt_A'), plt.imshow(
                    tensor2image_RGB(idt_A[0, ...]))
                plt.subplot(244), plt.title('L_idt_A'), plt.imshow(
                    tensor2image_RGB(T_rec_A[0, ...]))

                plt.subplot(245), plt.title('real_B'), plt.imshow(
                    tensor2image_RGB(real_B[0, ...]))
                plt.subplot(246), plt.title('fake_A'), plt.imshow(
                    tensor2image_RGB(fake_A[0, ...]))
                plt.subplot(247), plt.title('idt_B'), plt.imshow(
                    tensor2image_RGB(idt_B[0, ...]))
                plt.subplot(248), plt.title('L_idt_B'), plt.imshow(
                    tensor2image_RGB(T_rec_B[0, ...]))

                plt.savefig(
                    os.path.join(checkpoints_pth, 'images',
                                 '%03d_%09d.jpg' % (epoch, i)))

            cur_iter += 1
            #break #-> debug

        ## till now, we finish one epoch, try to update the learning rate
        update_learning_rate(schedulers=scheduler,
                             opt=opt,
                             optimizer=optimizer_D)
        ## save the model
        if epoch % opt.ckp_freq == 0:
            #-> save models
            # torch.save(model.state_dict(), PATH)
            #-> load in models
            # model.load_state_dict(torch.load(PATH))
            # model.eval()
            if opt.gpu_id >= 0:
                E_a2Zb = E_a2Zb.cpu()
                G_Zb2b = G_Zb2b.cpu()
                T_Zb2Za = T_Zb2Za.cpu()
                D_b = D_b.cpu()

                E_b2Za = E_b2Za.cpu()
                G_Za2a = G_Za2a.cpu()
                T_Za2Zb = T_Za2Zb.cpu()
                D_a = D_a.cpu()
                '''
                torch.save( E_a2Zb.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-E_a2b.pth' % epoch))
                torch.save( G_Zb2b.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-G_b.pth' % epoch))
                torch.save(T_Zb2Za.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-T_b2a.pth' % epoch))
                torch.save(    D_b.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-D_b.pth' % epoch))

                torch.save( E_b2Za.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-E_b2a.pth' % epoch))
                torch.save( G_Za2a.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-G_a.pth' % epoch))
                torch.save(T_Za2Zb.cpu().state_dict(), os.path.join(checkpoints_pth, 'epoch_%3d-T_a2b.pth' % epoch))
                torch.save(    D_a.cpu().state_dict(), os.path.join(checkpoints_pth,   'epoch_%3d-D_a.pth' % epoch))
                '''
            torch.save(
                E_a2Zb.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-E_a2b.pth' % epoch))
            torch.save(
                G_Zb2b.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-G_b.pth' % epoch))
            torch.save(
                T_Zb2Za.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-T_b2a.pth' % epoch))
            torch.save(
                D_b.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-D_b.pth' % epoch))

            torch.save(
                E_b2Za.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-E_b2a.pth' % epoch))
            torch.save(
                G_Za2a.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-G_a.pth' % epoch))
            torch.save(
                T_Za2Zb.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-T_a2b.pth' % epoch))
            torch.save(
                D_a.state_dict(),
                os.path.join(checkpoints_pth, 'epoch_%3d-D_a.pth' % epoch))
            if opt.gpu_id >= 0:
                E_a2Zb = E_a2Zb.to(device)
                G_Zb2b = G_Zb2b.to(device)
                T_Zb2Za = T_Zb2Za.to(device)
                D_b = D_b.to(device)

                E_b2Za = E_b2Za.to(device)
                G_Za2a = G_Za2a.to(device)
                T_Za2Zb = T_Za2Zb.to(device)
                D_a = D_a.to(device)
            print("+Successfully saving models in epoch: %3d.-------------+" %
                  epoch)
        #break #-> debug
    record_fh.close()
    print("≧◔◡◔≦ Congratulation! Finishing the training!")
Example #18
	def __init__(self, args, lr=0.1, latent_dim=8, lambda_latent=0.5,
					lambda_kl=0.001, lambda_recon=10, is_train=True):
		## Parameters 
		self.batch_size = args.batch_size
		self.latent_dim = latent_dim
		self.image_size = args.img_size
		self.lambda_kl = lambda_kl 
		self.lambda_recon = lambda_recon
		self.lambda_latent = lambda_latent
		self.is_train = tf.placeholder(tf.bool, name= 'is_training')
		self.lr = tf.placeholder(tf.float32, name='learning_rate')
		self.A = tf.placeholder(tf.float32, [self.batch_size, self.image_size,
								self.image_size, 3], name= 'A') 
		self.B = tf.placeholder(tf.float32, [self.batch_size, self.image_size,
								self.image_size, 3], name= 'B')
		self.z = tf.placeholder(tf.float32, [self.batch_size, self.latent_dim], 
								name= 'z')

		## Augmentation
		def aug_img(image):
			aug_strength = 30
			aug_size = self.image_size + aug_strength
			image_resized = tf.image.resize_images(image, [aug_size, aug_size])
			image_cropped = tf.random_crop(image_resized, [self.batch_size, self.image_size,
								self.image_size, 3])
			## work-around as tf-flip doesn't support 4D-batch
			image_flipped = tf.map_fn(lambda image_iter: tf.image.random_flip_left_right(image_iter), image_cropped)
			return image_flipped
		A = tf.cond(self.is_train,
					 lambda: aug_img(self.A), lambda: self.A)
		B = tf.cond(self.is_train, 
					lambda: aug_img(self.B), lambda: self.B)
		## Generator
		with tf.variable_scope('generator'):
			Gen = Generator(self.image_size, self.is_train)

		## Discriminator
		with tf.variable_scope('discriminator'):
			Disc = Discriminator(self.image_size, self.is_train)

		## Encoder
		with tf.variable_scope('encoder'):
			Enc = Encoder(self.image_size, self.is_train, self.latent_dim)

		## cVAE-GAN
		with tf.variable_scope('encoder'):
			z_enc, z_enc_mu, z_enc_log_sigma = Enc(B)
		
		with tf.variable_scope('generator'):
			self.B_hat_enc = Gen(A, z_enc)

		## cLR-GAN 
		with tf.variable_scope('generator', reuse=True):
			self.B_hat = Gen(A, self.z)
		with tf.variable_scope('encoder', reuse= True):
			z_hat, z_hat_mu, z_hat_log_sigma = Enc(self.B_hat)

		## Disc
		with tf.variable_scope('discriminator'):
			self.real = Disc(B)
		with tf.variable_scope('discriminator', reuse=True):
			self.fake = Disc(self.B_hat)
			self.fake_enc = Disc(self.B_hat_enc)

		## losses
		self.vae_gan_cost = tf.reduce_mean(tf.squared_difference(self.real, 0.9)) + \
						tf.reduce_mean(tf.square(self.fake_enc))
		self.recon_img_cost = tf.reduce_mean(tf.abs(B - self.B_hat_enc))
		self.gan_cost = tf.reduce_mean(tf.squared_difference(self.real, 0.9)) + \
					tf.reduce_mean(tf.square(self.fake))
		self.recon_latent_cost = tf.reduce_mean(tf.abs(self.z-z_hat))
		self.kl_div_cost =  -0.5*tf.reduce_mean(1 + 2*z_enc_log_sigma - z_enc_mu**2 -\
							tf.exp(2* z_enc_log_sigma))
		self.vec_cost = [self.vae_gan_cost, self.recon_img_cost, self.gan_cost, self.recon_latent_cost, 
						self.kl_div_cost]
		weight_vec = [1, -self.lambda_recon, 1, -self.lambda_latent, self.lambda_kl]

		self.cost = tf.reduce_sum([self.vec_cost[i]* weight_vec[i] for i in range(len(self.vec_cost)) ])

		## Optimizers
		update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
		self.optim_gen = tf.train.AdamOptimizer(self.lr, beta1=0.5)
		self.optim_disc = tf.train.AdamOptimizer(self.lr, beta1=0.5)
		self.optim_enc = tf.train.AdamOptimizer(self.lr, beta1=0.5)
		
		## Collecting the trainable variables
		gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='bc_gan/generator')
		disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='bc_gan/discriminator')
		enc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='bc_gan/encoder')

		## Defining the training operation 
		with tf.control_dependencies(update_ops):
			self.train_op_gen = self.optim_gen.minimize(-self.cost, var_list=gen_vars)
			self.train_op_disc = self.optim_disc.minimize(self.cost, var_list= disc_vars)
			self.train_op_enc = self.optim_enc.minimize(-self.cost, var_list=enc_vars)

		## Joining the training ops
		self.train_ops = [self.train_op_gen, self.train_op_disc, self.train_op_enc]
		## Summary Create
		def summary_create(self):
			## Image summaries
			tf.summary.image('A', self.A[0:1])
			tf.summary.image('B', self.B[0:1])
			tf.summary.image('B^', self.B_hat[0:1])
			tf.summary.image('B^-enc', self.B_hat_enc[0:1])
			## GEN - DISC summaries - min max game  
			tf.summary.scalar('fake', tf.reduce_mean(self.fake))
			tf.summary.scalar('fake_enc', tf.reduce_mean(self.fake_enc))
			tf.summary.scalar('real', tf.reduce_mean(self.real))
			tf.summary.scalar('learning_rate', self.lr)
			## cost summaries		
			tf.summary.scalar('cost_vae_gan', self.vae_gan_cost)
			tf.summary.scalar('cost_recon_img', self.recon_img_cost)
			tf.summary.scalar('cost_gan_cost', self.gan_cost)
			tf.summary.scalar('cost_recon_latent', self.recon_latent_cost)
			tf.summary.scalar('cost_kl_div', self.kl_div_cost)
			tf.summary.scalar('cost_final', self.cost)
			## Merge Summaries
			self.merge_op = tf.summary.merge_all()

		summary_create(self)
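
A minimal TF1 driver for this model (a sketch; model, a_batch, b_batch, and z_batch are hypothetical, and the graph is assumed to be built under a 'bc_gan' variable scope as the get_collection calls imply):

sess = tf.Session()
sess.run(tf.global_variables_initializer())
feed = {model.A: a_batch, model.B: b_batch, model.z: z_batch,
        model.is_train: True, model.lr: 2e-4}  # hypothetical batches and lr
_, cost = sess.run([model.train_ops, model.cost], feed_dict=feed)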
Example #19
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # create the checkpoint directory before optionally loading from it
    savedir = 'checkpoints_%d' % (FLAGS.batch_size)
    if not os.path.exists(savedir):
        os.makedirs(savedir)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join(savedir, FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join(savedir, FLAGS.decoder_save)))
    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
    """
    optimizer definition
    """
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))
    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST dataset...')
    mnist = datasets.MNIST(root='mnist',
                           download=True,
                           train=True,
                           transform=transform_config)
    # Creating data indices for training and validation splits:
    dataset_size = len(mnist)
    split = 10000
    train_mnist, val_mnist = torch.utils.data.random_split(
        mnist, [dataset_size - split, split])

    # Creating PT data samplers and loaders:
    weights_train = torch.ones(len(mnist))
    weights_test = torch.ones(len(mnist))
    weights_train[val_mnist.indices] = 0
    weights_test[train_mnist.indices] = 0
    counts = torch.zeros(10)
    for i in range(10):
        idx_label = mnist.targets[train_mnist.indices].eq(i)
        counts[i] = idx_label.sum()
    max_count = float(counts.max())
    sum_counts = float(counts.sum())
    for i in range(10):
        idx_label = mnist.targets[train_mnist.indices].eq(
            i).nonzero().squeeze()
        weights_train[train_mnist.indices[idx_label]] = (sum_counts /
                                                         counts[i])

    train_sampler = SubsetRandomSampler(train_mnist.indices)
    valid_sampler = SubsetRandomSampler(val_mnist.indices)
    kwargs = {'num_workers': 1, 'pin_memory': True} if FLAGS.cuda else {}
    loader = DataLoader(mnist,
                        batch_size=FLAGS.batch_size,
                        sampler=train_sampler,
                        **kwargs)
    valid_loader = DataLoader(mnist,
                              batch_size=FLAGS.batch_size,
                              sampler=valid_sampler,
                              **kwargs)
    monitor = torch.zeros(FLAGS.end_epoch - FLAGS.start_epoch, 4)
    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )
        elbo_epoch = 0
        term1_epoch = 0
        term2_epoch = 0
        term3_epoch = 0
        for it, (image_batch, labels_batch) in enumerate(loader):
            # set zero_grad for the optimizer
            auto_encoder_optimizer.zero_grad()

            X = image_batch.cuda().detach().clone()
            elbo, reconstruction_proba, style_kl_divergence_loss, class_kl_divergence_loss = process(
                FLAGS, X, labels_batch, encoder, decoder)
            (-elbo).backward()
            auto_encoder_optimizer.step()
            elbo_epoch += elbo
            term1_epoch += reconstruction_proba
            term2_epoch += style_kl_divergence_loss
            term3_epoch += class_kl_divergence_loss

        print("Elbo epoch %.2f" % (elbo_epoch / (it + 1)))
        print("Rec. Proba %.2f" % (term1_epoch / (it + 1)))
        print("KL style %.2f" % (term2_epoch / (it + 1)))
        print("KL content %.2f" % (term3_epoch / (it + 1)))
        # save checkpoints after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            monitor[epoch, :] = eval(FLAGS, valid_loader, encoder, decoder)
            torch.save(
                encoder.state_dict(),
                os.path.join(savedir, FLAGS.encoder_save + '_e%d' % epoch))
            torch.save(
                decoder.state_dict(),
                os.path.join(savedir, FLAGS.decoder_save + '_e%d' % epoch))
            print("VAL elbo %.2f" % (monitor[epoch, 0]))
            print("VAL Rec. Proba %.2f" % (monitor[epoch, 1]))
            print("VAL KL style %.2f" % (monitor[epoch, 2]))
            print("VAL KL content %.2f" % (monitor[epoch, 3]))

            torch.save(monitor, os.path.join(savedir, 'monitor_e%d' % epoch))
Example #20
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    discriminator = Discriminator()
    discriminator.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        raise Exception('This is not implemented')
        encoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))

    """
    variable definition
    """

    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels, FLAGS.image_size, FLAGS.image_size)

    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)

    """
    loss definitions
    """
    cross_entropy_loss = nn.CrossEntropyLoss()
    adversarial_loss = nn.BCELoss()

    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

        cross_entropy_loss.cuda()
        adversarial_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        style_latent_space = style_latent_space.cuda()

    """
    optimizer and scheduler definition
    """
    auto_encoder_optimizer = optim.Adam(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    reverse_cycle_optimizer = optim.Adam(
        list(encoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    generator_optimizer = optim.Adam(
        list(decoder.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    discriminator_optimizer = optim.Adam(
        list(discriminator.parameters()),
        lr=FLAGS.initial_learning_rate,
        betas=(FLAGS.beta_1, FLAGS.beta_2)
    )

    # divide the learning rate by a factor of 10 after 80 epochs
    auto_encoder_scheduler = optim.lr_scheduler.StepLR(auto_encoder_optimizer, step_size=80, gamma=0.1)
    reverse_cycle_scheduler = optim.lr_scheduler.StepLR(reverse_cycle_optimizer, step_size=80, gamma=0.1)
    generator_scheduler = optim.lr_scheduler.StepLR(generator_optimizer, step_size=80, gamma=0.1)
    discriminator_scheduler = optim.lr_scheduler.StepLR(discriminator_optimizer, step_size=80, gamma=0.1)

    # Used later to define discriminator ground truths
    Tensor = torch.cuda.FloatTensor if FLAGS.cuda else torch.FloatTensor

    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            headers = ['Epoch', 'Iteration', 'Reconstruction_loss', 'KL_divergence_loss', 'Reverse_cycle_loss']

            if FLAGS.forward_gan:
              headers.extend(['Generator_forward_loss', 'Discriminator_forward_loss'])

            if FLAGS.reverse_gan:
              headers.extend(['Generator_reverse_loss', 'Discriminator_reverse_loss'])

            log.write('\t'.join(headers) + '\n')

    # load data set and create data loader instance
    print('Loading CIFAR paired dataset...')
    paired_cifar = CIFAR_Paired(root='cifar', download=True, train=True, transform=transform_config)
    loader = cycle(DataLoader(paired_cifar, batch_size=FLAGS.batch_size, shuffle=True, num_workers=0, drop_last=True))

    # Save a batch of images to use for visualization
    image_sample_1, image_sample_2, _ = next(loader)
    image_sample_3, _, _ = next(loader)

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print('Epoch #' + str(epoch) + '..........................................................................')

        # update the learning rate scheduler
        auto_encoder_scheduler.step()
        reverse_cycle_scheduler.step()
        generator_scheduler.step()
        discriminator_scheduler.step()

        for iteration in range(int(len(paired_cifar) / FLAGS.batch_size)):
            # Adversarial ground truths
            valid = Variable(Tensor(FLAGS.batch_size, 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(FLAGS.batch_size, 1).fill_(0.0), requires_grad=False)

            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(Variable(X_1))
            style_latent_space_1 = reparameterize(training=True, mu=style_mu_1, logvar=style_logvar_1)

            kl_divergence_loss_1 = FLAGS.kl_divergence_coef * (
                - 0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) - style_logvar_1.exp())
            )
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_1.backward(retain_graph=True)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(Variable(X_2))
            style_latent_space_2 = reparameterize(training=True, mu=style_mu_2, logvar=style_logvar_2)

            kl_divergence_loss_2 = FLAGS.kl_divergence_coef * (
                - 0.5 * torch.sum(1 + style_logvar_2 - style_mu_2.pow(2) - style_logvar_2.exp())
            )
            kl_divergence_loss_2 /= (FLAGS.batch_size * FLAGS.num_channels * FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_2.backward(retain_graph=True)

            reconstructed_X_1 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2, class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            reconstruction_error = (reconstruction_error_1 + reconstruction_error_2) / FLAGS.reconstruction_coef
            kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2) / FLAGS.kl_divergence_coef

            auto_encoder_optimizer.step()

            # A-1. Discriminator training during forward cycle
            if FLAGS.forward_gan:
              # Training generator
              generator_optimizer.zero_grad()

              g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
              g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

              gen_f_loss = (g_loss_1 + g_loss_2) / 2.0
              gen_f_loss.backward()

              generator_optimizer.step()

              # Training discriminator
              discriminator_optimizer.zero_grad()

              real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
              real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
              fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
              fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

              dis_f_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
              dis_f_loss.backward()

              discriminator_optimizer.step()

            # B. reverse cycle
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_latent_space.normal_(0., 1.)

            _, __, class_latent_space_1 = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))

            reconstructed_X_1 = decoder(Variable(style_latent_space), class_latent_space_1.detach())
            reconstructed_X_2 = decoder(Variable(style_latent_space), class_latent_space_2.detach())

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False, mu=style_mu_2, logvar=style_logvar_2)

            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()
            reverse_cycle_loss /= FLAGS.reverse_cycle_coef

            reverse_cycle_optimizer.step()

            # B-1. Discriminator training during reverse cycle
            if FLAGS.reverse_gan:
              # Training generator
              generator_optimizer.zero_grad()

              g_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), valid)
              g_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), valid)

              gen_r_loss = (g_loss_1 + g_loss_2) / 2.0
              gen_r_loss.backward()

              generator_optimizer.step()

              # Training discriminator
              discriminator_optimizer.zero_grad()

              real_loss_1 = adversarial_loss(discriminator(Variable(X_1)), valid)
              real_loss_2 = adversarial_loss(discriminator(Variable(X_2)), valid)
              fake_loss_1 = adversarial_loss(discriminator(Variable(reconstructed_X_1)), fake)
              fake_loss_2 = adversarial_loss(discriminator(Variable(reconstructed_X_2)), fake)

              dis_r_loss = (real_loss_1 + real_loss_2 + fake_loss_1 + fake_loss_2) / 4.0
              dis_r_loss.backward()

              discriminator_optimizer.step()

            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' + str(reconstruction_error.data.storage().tolist()[0]))
                print('KL-Divergence loss: ' + str(kl_divergence_error.data.storage().tolist()[0]))
                print('Reverse cycle loss: ' + str(reverse_cycle_loss.data.storage().tolist()[0]))

                if FLAGS.forward_gan:
                  print('Generator F loss: ' + str(gen_f_loss.data.storage().tolist()[0]))
                  print('Discriminator F loss: ' + str(dis_f_loss.data.storage().tolist()[0]))

                if FLAGS.reverse_gan:
                  print('Generator R loss: ' + str(gen_r_loss.data.storage().tolist()[0]))
                  print('Discriminator R loss: ' + str(dis_r_loss.data.storage().tolist()[0]))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                row = []

                row.append(epoch)
                row.append(iteration)
                row.append(reconstruction_error.data.storage().tolist()[0])
                row.append(kl_divergence_error.data.storage().tolist()[0])
                row.append(reverse_cycle_loss.data.storage().tolist()[0])

                if FLAGS.forward_gan:
                    row.append(gen_f_loss.data.storage().tolist()[0])
                    row.append(dis_f_loss.data.storage().tolist()[0])

                if FLAGS.reverse_gan:
                    row.append(gen_r_loss.data.storage().tolist()[0])
                    row.append(dis_r_loss.data.storage().tolist()[0])

                row = [str(x) for x in row]
                log.write('\t'.join(row) + '\n')

            # write to tensorboard
            writer.add_scalar('Reconstruction loss', reconstruction_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('KL-Divergence loss', kl_divergence_error.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar('Reverse cycle loss', reverse_cycle_loss.data.storage().tolist()[0],
                              epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

            if FLAGS.forward_gan:
                writer.add_scalar('Generator F loss', gen_f_loss.data.storage().tolist()[0],
                                  epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
                writer.add_scalar('Discriminator F loss', dis_f_loss.data.storage().tolist()[0],
                                  epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

            if FLAGS.reverse_gan:
                writer.add_scalar('Generator R loss', gen_r_loss.data.storage().tolist()[0],
                                  epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)
                writer.add_scalar('Discriminator R loss', dis_r_loss.data.storage().tolist()[0],
                                  epoch * (int(len(paired_cifar) / FLAGS.batch_size) + 1) + iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(), os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(), os.path.join('checkpoints', FLAGS.decoder_save))

            """
            save reconstructed images and style swapped image generations to check progress
            """

            X_1.copy_(image_sample_1)
            X_2.copy_(image_sample_2)
            X_3.copy_(image_sample_3)

            style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))
            style_mu_3, style_logvar_3, _ = encoder(Variable(X_3))

            style_latent_space_1 = reparameterize(training=False, mu=style_mu_1, logvar=style_logvar_1)
            style_latent_space_3 = reparameterize(training=False, mu=style_mu_3, logvar=style_logvar_3)

            reconstructed_X_1_2 = decoder(style_latent_space_1, class_latent_space_2)
            reconstructed_X_3_2 = decoder(style_latent_space_3, class_latent_space_2)

            # save input image batch
            image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
                image_batch = np.concatenate((image_batch, image_batch, image_batch), axis=3)
            imshow_grid(image_batch, name=str(epoch) + '_original', save=True)

            # save reconstructed batch
            reconstructed_x = np.transpose(reconstructed_X_1_2.cpu().data.numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
                reconstructed_x = np.concatenate((reconstructed_x, reconstructed_x, reconstructed_x), axis=3)
            imshow_grid(reconstructed_x, name=str(epoch) + '_target', save=True)

            # save style image batch
            style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
                style_batch = np.concatenate((style_batch, style_batch, style_batch), axis=3)
            imshow_grid(style_batch, name=str(epoch) + '_style', save=True)

            # save style swapped reconstructed batch
            reconstructed_style = np.transpose(reconstructed_X_3_2.cpu().data.numpy(), (0, 2, 3, 1))
            if FLAGS.num_channels == 1:
                reconstructed_style = np.concatenate((reconstructed_style, reconstructed_style, reconstructed_style), axis=3)
            imshow_grid(reconstructed_style, name=str(epoch) + '_style_target', save=True)
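
Note: `reparameterize` is called throughout these training scripts but not shown. A minimal sketch of the usual VAE reparameterization trick with this signature (the body is an assumption, not the project's exact code):

import torch
from torch.autograd import Variable


def reparameterize(training, mu, logvar):
    # At train time, sample z = mu + sigma * eps with eps ~ N(0, I) so the
    # sampling step stays differentiable; at eval time, return the mean.
    if training:
        std = logvar.mul(0.5).exp_()
        eps = Variable(std.data.new(std.size()).normal_())
        return eps.mul(std).add_(mu)
    return mu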
Example #21
    def __init__(self, hyperparameters):
        super(LSGANs_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.encoder = Encoder(hyperparameters['input_dim_a'],
                               hyperparameters['gen'])
        self.decoder = Decoder(hyperparameters['input_dim_a'],
                               hyperparameters['gen'])
        self.dis_a = Discriminator()
        self.dis_b = Discriminator()
        self.interp_net_ab = Interpolator()
        self.interp_net_ba = Interpolator()
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        enc_params = list(self.encoder.parameters())
        dec_params = list(self.decoder.parameters())
        dis_a_params = list(self.dis_a.parameters())
        dis_b_params = list(self.dis_b.parameters())
        interpolator_ab_params = list(self.interp_net_ab.parameters())
        interpolator_ba_params = list(self.interp_net_ba.parameters())

        self.enc_opt = torch.optim.Adam(
            [p for p in enc_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dec_opt = torch.optim.Adam(
            [p for p in dec_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_a_opt = torch.optim.Adam(
            [p for p in dis_a_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.dis_b_opt = torch.optim.Adam(
            [p for p in dis_b_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.interp_ab_opt = torch.optim.Adam(
            [p for p in interpolator_ab_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])
        self.interp_ba_opt = torch.optim.Adam(
            [p for p in interpolator_ba_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters['weight_decay'])

        self.enc_scheduler = get_scheduler(self.enc_opt, hyperparameters)
        self.dec_scheduler = get_scheduler(self.dec_opt, hyperparameters)
        self.dis_a_scheduler = get_scheduler(self.dis_a_opt, hyperparameters)
        self.dis_b_scheduler = get_scheduler(self.dis_b_opt, hyperparameters)
        self.interp_ab_scheduler = get_scheduler(self.interp_ab_opt,
                                                 hyperparameters)
        self.interp_ba_scheduler = get_scheduler(self.interp_ba_opt,
                                                 hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] +
                                  '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        self.total_loss = 0
        self.best_iter = 0
        self.perceptural_loss = Perceptural_loss()
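
Note: unlike the plain `weights_init` used elsewhere in these examples, this trainer calls `weights_init(init_type)` and applies the return value, so it is presumably a factory that returns an init closure (MUNIT-style). A sketch under that assumption; the exact set of supported init strings is a guess:

import math

import torch.nn.init as init


def weights_init(init_type='gaussian'):
    # Return a closure suitable for nn.Module.apply()
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.startswith('Conv') or classname.startswith('Linear')) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                init.normal_(m.weight.data, 0.0, 0.02)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=math.sqrt(2))
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            else:
                raise ValueError('Unsupported initialization: {}'.format(init_type))
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
    return init_fun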
Example #22
def test(opt):
    #### mkdir
    des_pth = os.path.join('results', opt.name)
    if not os.path.exists(des_pth):
        os.makedirs(des_pth)
    src_pth = os.path.join(opt.checkpoints, opt.name)

    models_name = os.listdir(src_pth)
    models_name.remove('images')
    models_name.remove('records.txt')
    models_name.sort(key=lambda x: int(x[6:9]))
    target = int(models_name[-1][6:9])

    #### device
    device = torch.device('cuda:{}'.format(opt.gpu_id) if opt.gpu_id >= 0 else 'cpu')

    #### data
    data_loader = UnAlignedDataLoader()
    data_loader.initialize(opt)
    data_set = data_loader.load_data()

    #### networks
    ## initialize
    E_a2b = Encoder(input_nc=opt.input_nc, ngf=opt.ngf, norm_type=opt.norm_type, use_dropout=not opt.no_dropout, n_blocks=9)
    G_b = Decoder(output_nc=opt.output_nc, ngf=opt.ngf, norm_type=opt.norm_type)
    E_b2a = Encoder(input_nc=opt.input_nc, ngf=opt.ngf, norm_type=opt.norm_type, use_dropout=not opt.no_dropout, n_blocks=9)
    G_a = Decoder(output_nc=opt.output_nc, ngf=opt.ngf, norm_type=opt.norm_type)

    ## load in models
    E_a2b.load_state_dict(torch.load(os.path.join(src_pth, 'epoch_%3d-E_a2b.pth'%target)))
    G_b.load_state_dict(torch.load(os.path.join(src_pth, 'epoch_%3d-G_b.pth'%target)))
    E_b2a.load_state_dict(torch.load(os.path.join(src_pth, 'epoch_%3d-E_b2a.pth' % target)))
    G_a.load_state_dict(torch.load(os.path.join(src_pth, 'epoch_%3d-G_a.pth' % target)))

    E_a2b = E_a2b.to(device)
    G_b = G_b.to(device)
    E_b2a = E_b2a.to(device)
    G_a = G_a.to(device)

    for i, data in enumerate(data_set):
        real_A = data['A'].to(device)
        real_B = data['B'].to(device)

        fake_B = G_b(E_a2b(real_A))
        fake_A = G_a(E_b2a(real_B))

        ## visualize
        if opt.gpu_id >= 0:
            fake_B = fake_B.cpu().data
            fake_A = fake_A.cpu().data

            real_A = real_A.cpu()
            real_B = real_B.cpu()

        for j in range(opt.batch_size):
            fake_b = tensor2image_RGB(fake_B[j, ...])
            fake_a = tensor2image_RGB(fake_A[j, ...])

            real_a = tensor2image_RGB(real_A[j, ...])
            real_b = tensor2image_RGB(real_B[j, ...])

            plt.subplot(221), plt.title("real_A"), plt.imshow(real_a)
            plt.subplot(222), plt.title("fake_B"), plt.imshow(fake_b)
            plt.subplot(223), plt.title("real_B"), plt.imshow(real_b)
            plt.subplot(224), plt.title("fake_A"), plt.imshow(fake_a)

            plt.savefig(os.path.join(des_pth, '%06d-%02d.jpg'%(i, j)))

    print("≧◔◡◔≦ Congratulation! Successfully finishing the testing!")
Example #23
def main(args):
    train, test = chainer.datasets.get_cifar10()
    test_iter = iterators.SerialIterator(test, 1, shuffle=False, repeat=False)
    train_iter = iterators.SerialIterator(train,
                                          1,
                                          shuffle=False,
                                          repeat=False)

    encoder = Encoder()
    serializers.load_npz(args.input, encoder)

    if args.device >= 0:
        encoder.to_gpu(args.device)
    else:
        raise ValueError("Currently only GPU mode works, sorry!")

    _t = -1
    while _t != args.label:
        test_batch = test_iter.next()
        x, t = concat_examples(test_batch, args.device)
        key, f = encoder(x)
        _t = t.get().tolist()[0]

    distance = []
    features = []
    truth = []
    image = []

    c = 0
    with chainer.using_config('train', False):
        for train_batch in test_iter:
            _x, _t = concat_examples(train_batch, args.device)
            _y, _f = encoder(_x)

            dist = F.mean_absolute_error(key,
                                         _y).data.get().flatten().tolist()[0]
            true = _t.get().tolist()[0]
            pic = _x.get()[0].transpose(1, 2, 0)

            distance.append(dist)
            truth.append(true)
            image.append(pic)
            c += 1

            if c % 1000 == 0:
                print(c)

    idx = sorted(range(len(distance)), key=distance.__getitem__)

    for i in idx[:10]:
        print(distance[i], truth[i])

    print("original", t)

    img = x.get()[0].transpose(1, 2, 0)

    middle_row = np.concatenate([image[i] for i in idx[:11]], axis=1)
    top_row = np.concatenate(
        ((img, ) + tuple([np.ones_like(img) for i in range(10)])), axis=1)
    bottom_row = np.concatenate(tuple([image[i] for i in idx[-11:]]), axis=1)

    _img = np.concatenate((top_row, middle_row, bottom_row), axis=0)
    plt.imshow(_img)

    plt.axis('off')
    plt.show()
    plt.imsave(os.path.join(args.output, "img.png"), _img)
Example #24
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
    """
    variable definition
    """
    X = torch.FloatTensor(FLAGS.batch_size, 1, FLAGS.image_size,
                          FLAGS.image_size)
    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

        X = X.cuda()
    """
    optimizer definition
    """
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))
    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST dataset...')
    mnist = datasets.MNIST(root='mnist',
                           download=True,
                           train=True,
                           transform=transform_config)
    loader = cycle(
        DataLoader(mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        for iteration in range(int(len(mnist) / FLAGS.batch_size)):
            # load a mini-batch
            image_batch, labels_batch = next(loader)

            # set zero_grad for the optimizer
            auto_encoder_optimizer.zero_grad()

            X.copy_(image_batch)

            style_mu, style_logvar, class_mu, class_logvar = encoder(
                Variable(X))
            grouped_mu, grouped_logvar = accumulate_group_evidence(
                class_mu.data, class_logvar.data, labels_batch, FLAGS.cuda)

            # kl-divergence error for style latent space
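            # closed form: KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)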
            style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                 style_logvar.exp()))
            style_kl_divergence_loss /= (FLAGS.batch_size *
                                         FLAGS.num_channels *
                                         FLAGS.image_size * FLAGS.image_size)
            style_kl_divergence_loss.backward(retain_graph=True)

            # kl-divergence error for class latent space
            class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                                 grouped_logvar.exp()))
            class_kl_divergence_loss /= (FLAGS.batch_size *
                                         FLAGS.num_channels *
                                         FLAGS.image_size * FLAGS.image_size)
            class_kl_divergence_loss.backward(retain_graph=True)

            # reconstruct samples
            """
            sampling from group mu and logvar for each image in mini-batch differently makes
            the decoder consider class latent embeddings as random noise and ignore them 
            """
            style_latent_embeddings = reparameterize(training=True,
                                                     mu=style_mu,
                                                     logvar=style_logvar)
            class_latent_embeddings = group_wise_reparameterize(
                training=True,
                mu=grouped_mu,
                logvar=grouped_logvar,
                labels_batch=labels_batch,
                cuda=FLAGS.cuda)

            reconstructed_images = decoder(style_latent_embeddings,
                                           class_latent_embeddings)

            reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                reconstructed_images, Variable(X))
            reconstruction_error.backward()

            auto_encoder_optimizer.step()

            if (iteration + 1) % 50 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' +
                      str(reconstruction_error.data.storage().tolist()[0]))
                print('Style KL-Divergence loss: ' +
                      str(style_kl_divergence_loss.data.storage().tolist()[0]))
                print('Class KL-Divergence loss: ' +
                      str(class_kl_divergence_loss.data.storage().tolist()[0]))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                    epoch, iteration,
                    reconstruction_error.data.storage().tolist()[0],
                    style_kl_divergence_loss.data.storage().tolist()[0],
                    class_kl_divergence_loss.data.storage().tolist()[0]))

            # write to tensorboard
            writer.add_scalar(
                'Reconstruction loss',
                reconstruction_error.data.storage().tolist()[0],
                epoch * (int(len(mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar(
                'Style KL-Divergence loss',
                style_kl_divergence_loss.data.storage().tolist()[0],
                epoch * (int(len(mnist) / FLAGS.batch_size) + 1) + iteration)
            writer.add_scalar(
                'Class KL-Divergence loss',
                class_kl_divergence_loss.data.storage().tolist()[0],
                epoch * (int(len(mnist) / FLAGS.batch_size) + 1) + iteration)

        # save checkpoints after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.decoder_save))
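
Note: `accumulate_group_evidence` implements the grouped-VAE idea of multiplying the per-sample class posteriors within a group into one Gaussian (a precision-weighted product of Gaussians); `group_wise_reparameterize` then draws a single shared sample per group. A simplified sketch of the accumulation step, assuming diagonal Gaussians and a 1-D label tensor (CUDA placement and edge cases are omitted, so the body is an assumption):

import torch


def accumulate_group_evidence(class_mu, class_logvar, labels_batch, is_cuda):
    # Product of Gaussians: the group precision is the sum of the members'
    # precisions, and the group mean is the precision-weighted average.
    var = class_logvar.exp()
    grouped_mu = class_mu.clone()
    grouped_logvar = class_logvar.clone()
    for label in labels_batch.unique():
        mask = labels_batch == label
        group_var = 1.0 / (1.0 / var[mask]).sum(dim=0, keepdim=True)
        group_mu = group_var * (class_mu[mask] / var[mask]).sum(dim=0, keepdim=True)
        grouped_mu[mask] = group_mu
        grouped_logvar[mask] = group_var.log()
    return grouped_mu, grouped_logvar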
Example #25
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
    """
    variable definition
    """
    X = torch.FloatTensor(FLAGS.batch_size, 784)
    '''
    run on GPU if GPU is available
    '''
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder.to(device=device)
    decoder.to(device=device)
    X = X.to(device=device)
    """
    optimizer definition
    """
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))
    """
    
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tStyle_KL_divergence_loss\tClass_KL_divergence_loss\n'
            )

    # load data set and create data loader instance
    dirs = os.listdir(os.path.join(os.getcwd(), 'data'))
    print('Loading double multivariate normal time series data...')
    for dsname in dirs:
        params = dsname.split('_')
        if params[2] == 'theta=-1':
            print('Running dataset ', dsname)
            ds = DoubleMulNormal(dsname)
            # ds = experiment3(1000, 50, 3)
            loader = cycle(
                DataLoader(ds,
                           batch_size=FLAGS.batch_size,
                           shuffle=True,
                           drop_last=True))

            # initialize summary writer
            writer = SummaryWriter()

            for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
                print()
                print(
                    'Epoch #' + str(epoch) +
                    '........................................................')

                # the total loss at each epoch after running iterations of batches
                total_loss = 0

                for iteration in range(int(len(ds) / FLAGS.batch_size)):
                    # load a mini-batch
                    image_batch, labels_batch = next(loader)

                    # set zero_grad for the optimizer
                    auto_encoder_optimizer.zero_grad()

                    X.copy_(image_batch)

                    style_mu, style_logvar, class_mu, class_logvar = encoder(
                        Variable(X))
                    grouped_mu, grouped_logvar = accumulate_group_evidence(
                        class_mu.data, class_logvar.data, labels_batch,
                        FLAGS.cuda)

                    # kl-divergence error for style latent space
                    style_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                        -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) -
                                         style_logvar.exp()))
                    style_kl_divergence_loss /= (FLAGS.batch_size *
                                                 FLAGS.num_channels *
                                                 FLAGS.image_size *
                                                 FLAGS.image_size)
                    style_kl_divergence_loss.backward(retain_graph=True)

                    # kl-divergence error for class latent space
                    class_kl_divergence_loss = FLAGS.kl_divergence_coef * (
                        -0.5 *
                        torch.sum(1 + grouped_logvar - grouped_mu.pow(2) -
                                  grouped_logvar.exp()))
                    class_kl_divergence_loss /= (FLAGS.batch_size *
                                                 FLAGS.num_channels *
                                                 FLAGS.image_size *
                                                 FLAGS.image_size)
                    class_kl_divergence_loss.backward(retain_graph=True)

                    # reconstruct samples
                    """
                    sampling from group mu and logvar for each image in mini-batch differently makes
                    the decoder consider class latent embeddings as random noise and ignore them 
                    """
                    style_latent_embeddings = reparameterize(
                        training=True, mu=style_mu, logvar=style_logvar)
                    class_latent_embeddings = group_wise_reparameterize(
                        training=True,
                        mu=grouped_mu,
                        logvar=grouped_logvar,
                        labels_batch=labels_batch,
                        cuda=FLAGS.cuda)

                    reconstructed_images = decoder(style_latent_embeddings,
                                                   class_latent_embeddings)

                    reconstruction_error = FLAGS.reconstruction_coef * mse_loss(
                        reconstructed_images, Variable(X))
                    reconstruction_error.backward()

                    # detach so the running total does not keep every iteration's graph alive
                    total_loss += (style_kl_divergence_loss + class_kl_divergence_loss + reconstruction_error).detach()

                    auto_encoder_optimizer.step()

                    if (iteration + 1) % 50 == 0:
                        print('\tIteration #' + str(iteration))
                        print('Reconstruction loss: ' + str(
                            reconstruction_error.data.storage().tolist()[0]))
                        print('Style KL loss: ' +
                              str(style_kl_divergence_loss.data.storage().
                                  tolist()[0]))
                        print('Class KL loss: ' +
                              str(class_kl_divergence_loss.data.storage().
                                  tolist()[0]))

                    # write to log
                    with open(FLAGS.log_file, 'a') as log:
                        log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                            epoch, iteration,
                            reconstruction_error.data.storage().tolist()[0],
                            style_kl_divergence_loss.data.storage().tolist()
                            [0],
                            class_kl_divergence_loss.data.storage().tolist()
                            [0]))

                    # write to tensorboard
                    writer.add_scalar(
                        'Reconstruction loss',
                        reconstruction_error.data.storage().tolist()[0],
                        epoch * (int(len(ds) / FLAGS.batch_size) + 1) +
                        iteration)
                    writer.add_scalar(
                        'Style KL-Divergence loss',
                        style_kl_divergence_loss.data.storage().tolist()[0],
                        epoch * (int(len(ds) / FLAGS.batch_size) + 1) +
                        iteration)
                    writer.add_scalar(
                        'Class KL-Divergence loss',
                        class_kl_divergence_loss.data.storage().tolist()[0],
                        epoch * (int(len(ds) / FLAGS.batch_size) + 1) +
                        iteration)

                    if epoch == 0 and (iteration + 1) % 50 == 0:
                        torch.save(
                            encoder.state_dict(),
                            os.path.join('checkpoints', 'encoder_' + dsname))
                        torch.save(
                            decoder.state_dict(),
                            os.path.join('checkpoints', 'decoder_' + dsname))

                # save checkpoints after every 10 epochs
                if (epoch + 1) % 10 == 0 or (epoch + 1) == FLAGS.end_epoch:
                    torch.save(
                        encoder.state_dict(),
                        os.path.join('checkpoints', 'encoder_' + dsname))
                    torch.save(
                        decoder.state_dict(),
                        os.path.join('checkpoints', 'decoder_' + dsname))

                print('Total loss at current epoch: ', total_loss.item())


if __name__ == '__main__':

    # model definition
    BATCH_SIZE = 1

    dataset = load_dataset()
    loader = cycle(
        DataLoader(dataset,
                   batch_size=BATCH_SIZE,
                   shuffle=True,
                   drop_last=True))

    encoder = Encoder()
    encoder.apply(weights_init)

    decoder = Decoder()
    decoder.apply(weights_init)

    encoder.load_state_dict(
        torch.load(os.path.join('checkpoints', ENCODER_SAVE)))
    decoder.load_state_dict(
        torch.load(os.path.join('checkpoints', DECODER_SAVE)))

    encoder.eval()
    decoder.eval()

    prediction_model = Prediction_Model()
    prediction_model.apply(weights_init)
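
Note: several of these scripts wrap their DataLoader in `cycle`. It is presumably a small generator that restarts the loader when it is exhausted, so training loops can call next(loader) indefinitely; itertools.cycle would instead cache a whole epoch in memory. A sketch:

def cycle(iterable):
    # Re-iterate the loader forever without caching the first pass
    while True:
        for item in iterable:
            yield item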
Example #27
    def __init__(self,
                 x_dim,
                 z_dim,
                 enc_architecture,
                 gen_architecture,
                 disc_architecture,
                 folder="./VAEGAN"):
        super(VAEGAN, self).__init__(
            x_dim, z_dim,
            [enc_architecture, gen_architecture, disc_architecture], folder)

        self._enc_architecture = self._architectures[0]
        self._gen_architecture = self._architectures[1]
        self._disc_architecture = self._architectures[2]

        ################# Define architecture
        last_layer_mean = [
            logged_dense, {
                "units": z_dim,
                "activation": tf.identity,
                "name": "Mean"
            }
        ]
        self._encoder_mean = Encoder(self._enc_architecture +
                                     [last_layer_mean],
                                     name="Encoder")
        last_layer_std = [
            logged_dense, {
                "units": z_dim,
                "activation": tf.identity,
                "name": "Std"
            }
        ]
        self._encoder_std = Encoder(self._enc_architecture + [last_layer_std],
                                    name="Encoder")

        self._gen_architecture[-1][1]["name"] = "Output"
        self._generator = Generator(self._gen_architecture, name="Generator")

        self._disc_architecture.append(
            [tf.layers.flatten, {
                "name": "Flatten"
            }])
        self._disc_architecture.append([
            logged_dense, {
                "units": 1,
                "activation": tf.nn.sigmoid,
                "name": "Output"
            }
        ])
        self._discriminator = Discriminator(self._disc_architecture,
                                            name="Discriminator")

        self._nets = [self._encoder_mean, self._generator, self._discriminator]

        ################# Connect inputs and networks
        self._mean_layer = self._encoder_mean.generate_net(self._X_input)
        self._std_layer = self._encoder_std.generate_net(self._X_input)

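        # reparameterization trick: the "Std" head is read as a log-variance,
        # so z = mean + exp(0.5 * logvar) * eps, with eps supplied via Z_input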
        self._output_enc_with_noise = self._mean_layer + tf.exp(
            0.5 * self._std_layer) * self._Z_input

        self._output_gen = self._generator.generate_net(
            self._output_enc_with_noise)
        self._output_gen_from_encoding = self._generator.generate_net(
            self._Z_input)

        assert self._output_gen.get_shape()[1:] == x_dim, (
            "Generator output must have shape of x_dim. Given: {}. Expected: {}."
            .format(self._output_gen.get_shape(), x_dim))

        self._output_disc_real = self._discriminator.generate_net(
            self._X_input)
        self._output_disc_fake_from_real = self._discriminator.generate_net(
            self._output_gen)
        self._output_disc_fake_from_latent = self._discriminator.generate_net(
            self._output_gen_from_encoding)

        ################# Finalize
        self._init_folders()
        self._verify_init()
Example #28
    def __init__(
        self,
        x_dim,
        y_dim,
        z_dim,
        gen_architecture,
        adversarial_architecture,
        folder="./CGAN",
        append_y_at_every_layer=None,
        is_patchgan=False,
        is_wasserstein=False,
        aux_architecture=None,
    ):
        architectures = [gen_architecture, adversarial_architecture]
        self._is_cycle_consistent = False
        if aux_architecture is not None:
            architectures.append(aux_architecture)
            self._is_cycle_consistent = True
        super(CGAN,
              self).__init__(x_dim=x_dim,
                             y_dim=y_dim,
                             z_dim=z_dim,
                             architectures=architectures,
                             folder=folder,
                             append_y_at_every_layer=append_y_at_every_layer)

        self._gen_architecture = self._architectures[0]
        self._adversarial_architecture = self._architectures[1]
        self._is_patchgan = is_patchgan
        self._is_wasserstein = is_wasserstein
        self._is_feature_matching = False

        ################# Define architecture
        if self._is_patchgan:
            f_xy = self._adversarial_architecture[-1][-1]["filters"]
            assert f_xy == 1, "If is PatchGAN, last layer of adversarial_XY needs 1 filter. Given: {}.".format(
                f_xy)

            a_xy = self._adversarial_architecture[-1][-1]["activation"]
            if self._is_wasserstein:
                assert a_xy == tf.identity, "If is PatchGAN, last layer of adversarial needs tf.identity. Given: {}.".format(
                    a_xy)
            else:
                assert a_xy == tf.nn.sigmoid, "If is PatchGAN, last layer of adversarial needs tf.nn.sigmoid. Given: {}.".format(
                    a_xy)
        else:
            self._adversarial_architecture.append(
                [tf.layers.flatten, {
                    "name": "Flatten"
                }])
            if self._is_wasserstein:
                self._adversarial_architecture.append([
                    logged_dense, {
                        "units": 1,
                        "activation": tf.identity,
                        "name": "Output"
                    }
                ])
            else:
                self._adversarial_architecture.append([
                    logged_dense, {
                        "units": 1,
                        "activation": tf.nn.sigmoid,
                        "name": "Output"
                    }
                ])
        self._gen_architecture[-1][1]["name"] = "Output"

        self._generator = ConditionalGenerator(self._gen_architecture,
                                               name="Generator")
        self._adversarial = Critic(self._adversarial_architecture,
                                   name="Adversarial")

        self._nets = [self._generator, self._adversarial]

        ################# Connect inputs and networks
        self._output_gen = self._generator.generate_net(
            self._mod_Z_input,
            append_elements_at_every_layer=self._append_at_every_layer,
            tf_trainflag=self._is_training)

        with tf.name_scope("InputsAdversarial"):
            if len(self._x_dim) == 1:
                self._input_real = tf.concat(
                    axis=1, values=[self._X_input, self._Y_input], name="real")
                self._input_fake = tf.concat(
                    axis=1,
                    values=[self._output_gen, self._Y_input],
                    name="fake")
            else:
                self._input_real = image_condition_concat(
                    inputs=self._X_input, condition=self._Y_input, name="real")
                self._input_fake = image_condition_concat(
                    inputs=self._output_gen,
                    condition=self._Y_input,
                    name="fake")

        self._output_adversarial_real = self._adversarial.generate_net(
            self._input_real, tf_trainflag=self._is_training)
        self._output_adversarial_fake = self._adversarial.generate_net(
            self._input_fake, tf_trainflag=self._is_training)

        assert self._output_gen.get_shape()[1:] == x_dim, (
            "Output of generator is {}, but x_dim is {}.".format(
                self._output_gen.get_shape(), x_dim))

        ################# Auxiliary network for cycle consistency
        if self._is_cycle_consistent:
            self._auxiliary = Encoder(self._architectures[2], name="Auxiliary")
            self._output_auxiliary = self._auxiliary.generate_net(
                self._output_gen, tf_trainflag=self._is_training)
            assert self._output_auxiliary.get_shape().as_list(
            ) == self._mod_Z_input.get_shape().as_list(), (
                "Wrong shape for auxiliary vs. mod Z: {} vs {}.".format(
                    self._output_auxiliary.get_shape(),
                    self._mod_Z_input.get_shape()))
            self._nets.append(self._auxiliary)

        ################# Finalize
        self._init_folders()
        self._verify_init()

        if self._is_patchgan:
            print("PATCHGAN chosen with output: {}.".format(
                self._output_adversarial_real.shape))
Example #29
def training_procedure(FLAGS):
    """
    model definition
    """
    encoder = Encoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    encoder.apply(weights_init)

    decoder = Decoder(style_dim=FLAGS.style_dim, class_dim=FLAGS.class_dim)
    decoder.apply(weights_init)

    # load saved models if load_saved flag is true
    if FLAGS.load_saved:
        encoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.encoder_save)))
        decoder.load_state_dict(
            torch.load(os.path.join('checkpoints', FLAGS.decoder_save)))
    """
    variable definition
    """

    X_1 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_2 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)
    X_3 = torch.FloatTensor(FLAGS.batch_size, FLAGS.num_channels,
                            FLAGS.image_size, FLAGS.image_size)

    style_latent_space = torch.FloatTensor(FLAGS.batch_size, FLAGS.style_dim)
    """
    loss definitions
    """
    cross_entropy_loss = nn.CrossEntropyLoss()
    '''
    add option to run on GPU
    '''
    if FLAGS.cuda:
        encoder.cuda()
        decoder.cuda()

        cross_entropy_loss.cuda()

        X_1 = X_1.cuda()
        X_2 = X_2.cuda()
        X_3 = X_3.cuda()

        style_latent_space = style_latent_space.cuda()
    """
    optimizer and scheduler definition
    """
    auto_encoder_optimizer = optim.Adam(list(encoder.parameters()) +
                                        list(decoder.parameters()),
                                        lr=FLAGS.initial_learning_rate,
                                        betas=(FLAGS.beta_1, FLAGS.beta_2))

    reverse_cycle_optimizer = optim.Adam(list(encoder.parameters()),
                                         lr=FLAGS.initial_learning_rate,
                                         betas=(FLAGS.beta_1, FLAGS.beta_2))

    # divide the learning rate by a factor of 10 after 80 epochs
    auto_encoder_scheduler = optim.lr_scheduler.StepLR(auto_encoder_optimizer,
                                                       step_size=80,
                                                       gamma=0.1)
    reverse_cycle_scheduler = optim.lr_scheduler.StepLR(
        reverse_cycle_optimizer, step_size=80, gamma=0.1)
    """
    training
    """
    if torch.cuda.is_available() and not FLAGS.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')

    if not os.path.exists('reconstructed_images'):
        os.makedirs('reconstructed_images')

    # load_saved is false when training is started from 0th iteration
    if not FLAGS.load_saved:
        with open(FLAGS.log_file, 'w') as log:
            log.write(
                'Epoch\tIteration\tReconstruction_loss\tKL_divergence_loss\tReverse_cycle_loss\n'
            )

    # load data set and create data loader instance
    print('Loading MNIST paired dataset...')
    paired_mnist = MNIST_Paired(root='mnist',
                                download=True,
                                train=True,
                                transform=transform_config)
    loader = cycle(
        DataLoader(paired_mnist,
                   batch_size=FLAGS.batch_size,
                   shuffle=True,
                   num_workers=0,
                   drop_last=True))

    # initialize summary writer
    writer = SummaryWriter()

    for epoch in range(FLAGS.start_epoch, FLAGS.end_epoch):
        print('')
        print(
            'Epoch #' + str(epoch) +
            '..........................................................................'
        )

        # update the learning rate scheduler
        auto_encoder_scheduler.step()
        reverse_cycle_scheduler.step()

        for iteration in range(int(len(paired_mnist) / FLAGS.batch_size)):
            # A. run the auto-encoder reconstruction
            image_batch_1, image_batch_2, _ = next(loader)

            auto_encoder_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_mu_1, style_logvar_1, class_latent_space_1 = encoder(
                Variable(X_1))
            style_latent_space_1 = reparameterize(training=True,
                                                  mu=style_mu_1,
                                                  logvar=style_logvar_1)

            kl_divergence_loss_1 = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + style_logvar_1 - style_mu_1.pow(2) -
                                 style_logvar_1.exp()))
            kl_divergence_loss_1 /= (FLAGS.batch_size * FLAGS.num_channels *
                                     FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_1.backward(retain_graph=True)

            style_mu_2, style_logvar_2, class_latent_space_2 = encoder(
                Variable(X_2))
            style_latent_space_2 = reparameterize(training=True,
                                                  mu=style_mu_2,
                                                  logvar=style_logvar_2)

            kl_divergence_loss_2 = FLAGS.kl_divergence_coef * (
                -0.5 * torch.sum(1 + style_logvar_2 - style_mu_2.pow(2) -
                                 style_logvar_2.exp()))
            kl_divergence_loss_2 /= (FLAGS.batch_size * FLAGS.num_channels *
                                     FLAGS.image_size * FLAGS.image_size)
            kl_divergence_loss_2.backward(retain_graph=True)

            reconstructed_X_1 = decoder(style_latent_space_1,
                                        class_latent_space_2)
            reconstructed_X_2 = decoder(style_latent_space_2,
                                        class_latent_space_1)

            reconstruction_error_1 = FLAGS.reconstruction_coef * mse_loss(
                reconstructed_X_1, Variable(X_1))
            reconstruction_error_1.backward(retain_graph=True)

            reconstruction_error_2 = FLAGS.reconstruction_coef * mse_loss(
                reconstructed_X_2, Variable(X_2))
            reconstruction_error_2.backward()

            reconstruction_error = (
                reconstruction_error_1 +
                reconstruction_error_2) / FLAGS.reconstruction_coef
            kl_divergence_error = (kl_divergence_loss_1 + kl_divergence_loss_2
                                   ) / FLAGS.kl_divergence_coef

            auto_encoder_optimizer.step()

            # B. reverse cycle
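            # decode one shared random style vector with two different class
            # codes; re-encoding both outputs should recover the same style,
            # so their L1 distance is penalized as the reverse cycle loss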
            image_batch_1, _, __ = next(loader)
            image_batch_2, _, __ = next(loader)

            reverse_cycle_optimizer.zero_grad()

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)

            style_latent_space.normal_(0., 1.)

            _, __, class_latent_space_1 = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))

            reconstructed_X_1 = decoder(Variable(style_latent_space),
                                        class_latent_space_1.detach())
            reconstructed_X_2 = decoder(Variable(style_latent_space),
                                        class_latent_space_2.detach())

            style_mu_1, style_logvar_1, _ = encoder(reconstructed_X_1)
            style_latent_space_1 = reparameterize(training=False,
                                                  mu=style_mu_1,
                                                  logvar=style_logvar_1)

            style_mu_2, style_logvar_2, _ = encoder(reconstructed_X_2)
            style_latent_space_2 = reparameterize(training=False,
                                                  mu=style_mu_2,
                                                  logvar=style_logvar_2)

            reverse_cycle_loss = FLAGS.reverse_cycle_coef * l1_loss(
                style_latent_space_1, style_latent_space_2)
            reverse_cycle_loss.backward()
            reverse_cycle_loss /= FLAGS.reverse_cycle_coef

            reverse_cycle_optimizer.step()

            if (iteration + 1) % 10 == 0:
                print('')
                print('Epoch #' + str(epoch))
                print('Iteration #' + str(iteration))

                print('')
                print('Reconstruction loss: ' +
                      str(reconstruction_error.data.storage().tolist()[0]))
                print('KL-Divergence loss: ' +
                      str(kl_divergence_error.data.storage().tolist()[0]))
                print('Reverse cycle loss: ' +
                      str(reverse_cycle_loss.data.storage().tolist()[0]))

            # write to log
            with open(FLAGS.log_file, 'a') as log:
                log.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                    epoch, iteration,
                    reconstruction_error.data.storage().tolist()[0],
                    kl_divergence_error.data.storage().tolist()[0],
                    reverse_cycle_loss.data.storage().tolist()[0]))

            # write to tensorboard
            writer.add_scalar(
                'Reconstruction loss',
                reconstruction_error.data.storage().tolist()[0],
                epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) +
                iteration)
            writer.add_scalar(
                'KL-Divergence loss',
                kl_divergence_error.data.storage().tolist()[0],
                epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) +
                iteration)
            writer.add_scalar(
                'Reverse cycle loss',
                reverse_cycle_loss.data.storage().tolist()[0],
                epoch * (int(len(paired_mnist) / FLAGS.batch_size) + 1) +
                iteration)

        # save model after every 5 epochs
        if (epoch + 1) % 5 == 0 or (epoch + 1) == FLAGS.end_epoch:
            torch.save(encoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.encoder_save))
            torch.save(decoder.state_dict(),
                       os.path.join('checkpoints', FLAGS.decoder_save))
            """
            save reconstructed images and style swapped image generations to check progress
            """
            image_batch_1, image_batch_2, _ = next(loader)
            image_batch_3, _, __ = next(loader)

            X_1.copy_(image_batch_1)
            X_2.copy_(image_batch_2)
            X_3.copy_(image_batch_3)

            style_mu_1, style_logvar_1, _ = encoder(Variable(X_1))
            _, __, class_latent_space_2 = encoder(Variable(X_2))
            style_mu_3, style_logvar_3, _ = encoder(Variable(X_3))

            style_latent_space_1 = reparameterize(training=False,
                                                  mu=style_mu_1,
                                                  logvar=style_logvar_1)
            style_latent_space_3 = reparameterize(training=False,
                                                  mu=style_mu_3,
                                                  logvar=style_logvar_3)

            reconstructed_X_1_2 = decoder(style_latent_space_1,
                                          class_latent_space_2)
            reconstructed_X_3_2 = decoder(style_latent_space_3,
                                          class_latent_space_2)

            # save input image batch
            image_batch = np.transpose(X_1.cpu().numpy(), (0, 2, 3, 1))
            image_batch = np.concatenate(
                (image_batch, image_batch, image_batch), axis=3)
            imshow_grid(image_batch, name=str(epoch) + '_original', save=True)

            # save reconstructed batch
            reconstructed_x = np.transpose(
                reconstructed_X_1_2.cpu().data.numpy(), (0, 2, 3, 1))
            reconstructed_x = np.concatenate(
                (reconstructed_x, reconstructed_x, reconstructed_x), axis=3)
            imshow_grid(reconstructed_x,
                        name=str(epoch) + '_target',
                        save=True)

            style_batch = np.transpose(X_3.cpu().numpy(), (0, 2, 3, 1))
            style_batch = np.concatenate(
                (style_batch, style_batch, style_batch), axis=3)
            imshow_grid(style_batch, name=str(epoch) + '_style', save=True)

            # save style swapped reconstructed batch
            reconstructed_style = np.transpose(
                reconstructed_X_3_2.cpu().data.numpy(), (0, 2, 3, 1))
            reconstructed_style = np.concatenate(
                (reconstructed_style, reconstructed_style,
                 reconstructed_style),
                axis=3)
            imshow_grid(reconstructed_style,
                        name=str(epoch) + '_style_target',
                        save=True)
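
Note: `imshow_grid` is shared by the image-dumping blocks above. A plausible sketch using matplotlib's ImageGrid, assuming batches of HWC images; the 4x4 grid shape and the output path are assumptions (the script creates 'reconstructed_images' earlier):

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid


def imshow_grid(images, shape=(4, 4), name='default', save=False):
    # Lay a batch of HWC images out on a grid, then save or display it
    fig = plt.figure()
    grid = ImageGrid(fig, 111, nrows_ncols=shape, axes_pad=0.05)
    for i in range(shape[0] * shape[1]):
        grid[i].axis('off')
        grid[i].imshow(images[i])
    if save:
        plt.savefig('reconstructed_images/' + name + '.png')
        plt.close(fig)
    else:
        plt.show()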
Example #30
    def __init__(self,
                 x_dim,
                 y_dim,
                 z_dim,
                 enc_architecture,
                 gen_architecture,
                 adversarial_architecture,
                 folder="./CVAEGAN",
                 is_patchgan=False,
                 is_wasserstein=False):
        super(CVAEGAN, self).__init__(
            x_dim, y_dim,
            [enc_architecture, gen_architecture, adversarial_architecture],
            folder)

        self._z_dim = z_dim
        with tf.name_scope("Inputs"):
            self._Z_input = tf.placeholder(tf.float32,
                                           shape=[None, z_dim],
                                           name="z")

        self._enc_architecture = self._architectures[0]
        self._gen_architecture = self._architectures[1]
        self._adv_architecture = self._architectures[2]

        self._is_patchgan = is_patchgan
        self._is_wasserstein = is_wasserstein
        self._is_feature_matching = False

        ################# Define architecture
        if self._is_patchgan:
            f_xy = self._adv_architecture[-1][-1]["filters"]
            assert f_xy == 1, "If is PatchGAN, last layer of adversarial_XY needs 1 filter. Given: {}.".format(
                f_xy)

            a_xy = self._adv_architecture[-1][-1]["activation"]
            if self._is_wasserstein:
                assert a_xy == tf.identity, "If is PatchGAN, last layer of adversarial needs tf.identity. Given: {}.".format(
                    a_xy)
            else:
                assert a_xy == tf.nn.sigmoid, "If is PatchGAN, last layer of adversarial needs tf.nn.sigmoid. Given: {}.".format(
                    a_xy)
        else:
            self._adv_architecture.append(
                [tf.layers.flatten, {
                    "name": "Flatten"
                }])
            if self._is_wasserstein:
                self._adv_architecture.append([
                    logged_dense, {
                        "units": 1,
                        "activation": tf.identity,
                        "name": "Output"
                    }
                ])
            else:
                self._adv_architecture.append([
                    logged_dense, {
                        "units": 1,
                        "activation": tf.nn.sigmoid,
                        "name": "Output"
                    }
                ])

        last_layers_mean = [
            [tf.layers.flatten, {"name": "flatten"}],
            [logged_dense, {"units": z_dim, "activation": tf.identity, "name": "Mean"}],
        ]
        self._encoder_mean = Encoder(self._enc_architecture + last_layers_mean,
                                     name="Encoder")
        last_layers_std = [
            [tf.layers.flatten, {"name": "flatten"}],
            [logged_dense, {"units": z_dim, "activation": tf.identity, "name": "Std"}],
        ]
        self._encoder_std = Encoder(self._enc_architecture + last_layers_std,
                                    name="Encoder")

        self._gen_architecture[-1][1]["name"] = "Output"
        self._generator = Generator(self._gen_architecture, name="Generator")

        self._adversarial = Discriminator(self._adv_architecture,
                                          name="Adversarial")

        self._nets = [self._encoder_mean, self._generator, self._adversarial]

        ################# Connect inputs and networks
        self._mean_layer = self._encoder_mean.generate_net(self._Y_input)
        self._std_layer = self._encoder_std.generate_net(self._Y_input)

        self._output_enc_with_noise = self._mean_layer + tf.exp(
            0.5 * self._std_layer) * self._Z_input
        with tf.name_scope("Inputs"):
            self._gen_input = image_condition_concat(
                inputs=self._X_input,
                condition=self._output_enc_with_noise,
                name="mod_z_real")
            self._gen_input_from_encoding = image_condition_concat(
                inputs=self._X_input, condition=self._Z_input, name="mod_z")
        self._output_gen = self._generator.generate_net(self._gen_input)
        self._output_gen_from_encoding = self._generator.generate_net(
            self._gen_input_from_encoding)
        self._generator._input_dim = z_dim

        assert self._output_gen.get_shape()[1:] == y_dim, (
            "Generator output must have shape of y_dim. Given: {}. Expected: {}."
            .format(self._output_gen.get_shape(), y_dim))

        with tf.name_scope("InputsAdversarial"):
            self._input_real = tf.concat(values=[self._Y_input, self._X_input],
                                         axis=3)
            self._input_fake_from_real = tf.concat(
                values=[self._output_gen, self._X_input], axis=3)
            self._input_fake_from_latent = tf.concat(
                values=[self._output_gen_from_encoding, self._X_input], axis=3)

        self._output_adv_real = self._adversarial.generate_net(
            self._input_real)
        self._output_adv_fake_from_real = self._adversarial.generate_net(
            self._input_fake_from_real)
        self._output_adv_fake_from_latent = self._adversarial.generate_net(
            self._input_fake_from_latent)

        ################# Finalize
        self._init_folders()
        self._verify_init()

        self._output_label_real = tf.placeholder(
            tf.float32, shape=self._output_adv_real.shape, name="label_real")
        self._output_label_fake = tf.placeholder(
            tf.float32,
            shape=self._output_adv_fake_from_real.shape,
            name="label_fake")

        if self._is_patchgan:
            print("PATCHGAN chosen with output: {}.".format(
                self._output_adv_real.shape))