Example #1
    def backward_G_B(self):

        self.loss_G_adversarial_B = loss.adversarial_loss_generator(
            self.fakeBpred,
            self.outputBpred,
            method='L2',
            loss_weight_config=self.loss_weight_config)

        self.loss_G_reconstruction_B = loss.reconstruction_loss(
            self.outputB,
            self.realB,
            method='L1',
            loss_weight_config=self.loss_weight_config)

        self.loss_G_mask_B = loss.mask_loss(
            self.maskB,
            threshold=self.loss_config['mask_threshold'],
            method='L1',
            loss_weight_config=self.loss_weight_config)

        self.loss_G_B = self.loss_G_adversarial_B + self.loss_G_reconstruction_B + self.loss_G_mask_B

        if self.loss_config['pl_on']:
            self.loss_G_perceptual_B = loss.perceptual_loss(
                self.realB,
                self.fakeB,
                self.vggface,
                self.vggface_for_pl,
                method='L2',
                loss_weight_config=self.loss_weight_config)
            self.loss_G_B += self.loss_G_perceptual_B

        if self.loss_config['edgeloss_on']:
            self.loss_G_edge_B = loss.edge_loss(
                self.outputB,
                self.realB,
                self.mask_eye_B,
                method='L1',
                loss_weight_config=self.loss_weight_config)
            self.loss_G_B += self.loss_G_edge_B

        if self.loss_config['eyeloss_on']:
            self.loss_G_eye_B = loss.eye_loss(
                self.outputB,
                self.realB,
                self.mask_eye_B,
                method='L1',
                loss_weight_config=self.loss_weight_config)
            self.loss_G_B += self.loss_G_eye_B

        self.loss_G_B.backward(retain_graph=True)
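
Note: the `loss` helpers used above come from the project's own loss module, which is not shown. The following is a minimal sketch of what a weighted reconstruction term with this signature might look like; the 'reconstruction_loss' weight key and the F.l1_loss/F.mse_loss choices are assumptions, not the repository's actual implementation.

import torch.nn.functional as F

def reconstruction_loss(output, target, method='L1', loss_weight_config=None):
    # hypothetical sketch: choose the distance by name, then scale it by a
    # configurable weight (the 'reconstruction_loss' key is an assumed convention)
    weight = (loss_weight_config or {}).get('reconstruction_loss', 1.0)
    if method == 'L1':
        return weight * F.l1_loss(output, target)
    return weight * F.mse_loss(output, target)
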
Example #2
    def backward_G_A(self):

        self.loss_G_adversarial_A = loss.adversarial_loss_generator(
            self.fakeApred,
            self.outputApred,
            method='L2',
            loss_weight_config=self.loss_weight_config)

        self.loss_G_reconstruction_A = loss.reconstruction_loss(
            self.outputA,
            self.realA,
            method='L1',
            loss_weight_config=self.loss_weight_config)

        self.loss_G_mask_A = loss.mask_loss(
            self.maskA,
            threshold=self.loss_config['mask_threshold'],
            method='L1',
            loss_weight_config=self.loss_weight_config)

        self.loss_G_A = self.loss_G_adversarial_A + self.loss_G_reconstruction_A + self.loss_G_mask_A

        if self.loss_config['pl_on']:
            self.loss_G_perceptual_A = loss.perceptual_loss(
                self.realA,
                self.fakeA,
                self.vggface,
                self.vggface_for_pl,
                method='L2',
                loss_weight_config=self.loss_weight_config)
            self.loss_G_A += self.loss_G_perceptual_A

        if self.loss_config['edgeloss_on']:
            self.loss_G_edge_A = loss.edge_loss(
                self.outputA,
                self.realA,
                self.mask_eye_A,
                method='L1',
                loss_weight_config=self.loss_weight_config)
            self.loss_G_A += self.loss_G_edge_A

        if self.loss_config['eyeloss_on']:
            self.loss_G_eye_A = loss.eye_loss(
                self.outputA,
                self.realA,
                self.mask_eye_A,
                method='L1',
                loss_weight_config=self.loss_weight_config)
            self.loss_G_A += self.loss_G_eye_A

        self.loss_G_A.backward(retain_graph=True)
Example #3
def train(args):
    tf.enable_eager_execution()
    assert tf.executing_eagerly()
    tfe = tf.contrib.eager
    writer = tf.contrib.summary.create_file_writer("./log")
    global_step = tf.train.get_or_create_global_step()
    writer.set_as_default()
    dataset = make_dataset("./train.csv", args.batch_size, args.n_class)
    model = MultiSegCaps(n_class=args.n_class)
    #optimizer = tf.train.AdamOptimizer(learning_rate=args.lr)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=args.lr)
    checkpoint_dir = './models'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    root = tfe.Checkpoint(optimizer=optimizer,
                          model=model,
                          optimizer_step=tf.train.get_or_create_global_step())

    with tf.contrib.summary.record_summaries_every_n_global_steps(50):
        for epoch in range(args.epoch):
            for imgs, lbls in dataset:
                global_step.assign_add(1)
                with tf.GradientTape() as tape:
                    out_seg, reconstruct = model(imgs, lbls)

                    segmentation_loss = tf.losses.softmax_cross_entropy(
                        lbls, out_seg)
                    tf.contrib.summary.scalar('segmentation_loss',
                                              segmentation_loss)
                    # segmentation_loss = weighted_margin_loss(out_seg, lbls, class_weighting=[0,1,1,1,1])

                    reconstruct_loss = reconstruction_loss(reconstruct,
                                                           imgs,
                                                           rs=args.rs)
                    tf.contrib.summary.scalar('reconstruction_loss',
                                              reconstruct_loss)

                    total_loss = segmentation_loss + reconstruct_loss
                    tf.contrib.summary.scalar('total_loss', total_loss)
                    print(total_loss)
                grad = tape.gradient(total_loss, model.variables)
                optimizer.apply_gradients(
                    zip(grad, model.variables),
                    global_step=tf.train.get_or_create_global_step())

            if epoch % 10 == 0:
                root.save(file_prefix=checkpoint_prefix)
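
Note: `reconstruction_loss` is imported from elsewhere in that project. A minimal sketch consistent with the call `reconstruction_loss(reconstruct, imgs, rs=args.rs)`, assuming `rs` is a small down-weighting factor on a sum-of-squares term in the spirit of the CapsNet reconstruction regularizer:

import tensorflow as tf  # TF 1.x, matching the snippet above

def reconstruction_loss(reconstruct, imgs, rs):
    # hypothetical sketch: scaled sum of squared differences between the
    # reconstruction and the input; rs keeps it from dominating segmentation
    return rs * tf.reduce_sum(tf.square(reconstruct - imgs))
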
Example #4
    def train(self):

        self.set_mode(train=True)

        # prepare dataloader (iterable)
        print('Start loading data...')
        dset = DIGIT('./data', train=True)
        self.data_loader = torch.utils.data.DataLoader(dset, batch_size=self.batch_size, shuffle=True)
        test_dset = DIGIT('./data', train=False)
        self.test_data_loader = torch.utils.data.DataLoader(test_dset, batch_size=self.batch_size, shuffle=True)
        print('test: ', len(test_dset))
        self.N = len(self.data_loader.dataset)
        print('...done')

        # iterators from dataloader
        iterator1 = iter(self.data_loader)
        iterator2 = iter(self.data_loader)

        iter_per_epoch = min(len(iterator1), len(iterator2))

        start_iter = self.ckpt_load_iter + 1
        epoch = int(start_iter / iter_per_epoch)

        for iteration in range(start_iter, self.max_iter + 1):

            # reset data iterators for each epoch
            if iteration % iter_per_epoch == 0:
                print('==== epoch %d done ====' % epoch)
                epoch += 1
                iterator1 = iter(self.data_loader)
                iterator2 = iter(self.data_loader)

            # ============================================
            #          TRAIN THE VAE (ENC & DEC)
            # ============================================

            # sample a mini-batch
            XA, XB, index = next(iterator1)  # (n x C x H x W)

            index = index.cpu().detach().numpy()
            if self.use_cuda:
                XA = XA.cuda()
                XB = XB.cuda()

            # zA, zS = encA(xA)
            muA_infA, stdA_infA, logvarA_infA, cate_prob_infA = self.encoderA(XA)

            # zB, zS = encB(xB)
            muB_infB, stdB_infB, logvarB_infB, cate_prob_infB = self.encoderB(XB)

            # zS = encAB(xA,xB) via POE
            cate_prob_POE = torch.exp(
                torch.log(torch.tensor(1 / 10)) + torch.log(cate_prob_infA) + torch.log(cate_prob_infB))

            # latent_dist = {'cont': (muA_infA, logvarA_infA), 'disc': [cate_prob_infA]}
            # (kl_cont_loss, kl_disc_loss, cont_capacity_loss, disc_capacity_loss) = kl_loss_function(self.use_cuda, iteration, latent_dist)

            # kl losses
            #A
            latent_dist_infA = {'cont': (muA_infA, logvarA_infA), 'disc': [cate_prob_infA]}
            (kl_cont_loss_infA, kl_disc_loss_infA, cont_capacity_loss_infA, disc_capacity_loss_infA) = kl_loss_function(
                self.use_cuda, iteration, latent_dist_infA)

            loss_kl_infA = kl_cont_loss_infA + kl_disc_loss_infA
            capacity_loss_infA = cont_capacity_loss_infA + disc_capacity_loss_infA

            #B
            latent_dist_infB = {'cont': (muB_infB, logvarB_infB), 'disc': [cate_prob_infB]}
            (kl_cont_loss_infB, kl_disc_loss_infB, cont_capacity_loss_infB, disc_capacity_loss_infB) = kl_loss_function(
                self.use_cuda, iteration, latent_dist_infB,
                cont_capacity=[0.0, 5.0, 50000, 100.0], disc_capacity=[0.0, 10.0, 50000, 100.0])

            loss_kl_infB = kl_cont_loss_infB + kl_disc_loss_infB
            capacity_loss_infB = cont_capacity_loss_infB + disc_capacity_loss_infB


            loss_capa = capacity_loss_infB

            # encoder samples (for training)
            ZA_infA = sample_gaussian(self.use_cuda, muA_infA, stdA_infA)
            ZB_infB = sample_gaussian(self.use_cuda, muB_infB, stdB_infB)
            ZS_POE = sample_gumbel_softmax(self.use_cuda, cate_prob_POE)

            # encoder samples (for cross-modal prediction)
            ZS_infA = sample_gumbel_softmax(self.use_cuda, cate_prob_infA)
            ZS_infB = sample_gumbel_softmax(self.use_cuda, cate_prob_infB)

            # reconstructed samples (given joint modal observation)
            XA_POE_recon = self.decoderA(ZA_infA, ZS_POE)
            XB_POE_recon = self.decoderB(ZB_infB, ZS_POE)

            # reconstructed samples (given single modal observation)
            XA_infA_recon = self.decoderA(ZA_infA, ZS_infA)
            XB_infB_recon = self.decoderB(ZB_infB, ZS_infB)

            # loss_recon_infA = F.l1_loss(torch.sigmoid(XA_infA_recon), XA, reduction='sum').div(XA.size(0))
            loss_recon_infA = reconstruction_loss(XA, torch.sigmoid(XA_infA_recon), distribution="bernoulli")
            #
            loss_recon_infB = reconstruction_loss(XB, torch.sigmoid(XB_infB_recon), distribution="bernoulli")
            #
            loss_recon_POE = \
                F.l1_loss(torch.sigmoid(XA_POE_recon), XA, reduction='sum').div(XA.size(0)) + \
                F.l1_loss(torch.sigmoid(XB_POE_recon), XB, reduction='sum').div(XB.size(0))
            #

            loss_recon = loss_recon_infB

            # total loss for vae
            vae_loss = loss_recon + loss_capa

            # update vae
            self.optim_vae.zero_grad()
            vae_loss.backward()
            self.optim_vae.step()



            # print the losses
            if iteration % self.print_iter == 0:
                prn_str = (
                    '[iter %d (epoch %d)] vae_loss: %.3f '
                    '(recon: %.3f, capa: %.3f)\n'
                    '    rec_infA = %.3f, rec_infB = %.3f, rec_POE = %.3f\n'
                    '    kl_infA = %.3f, kl_infB = %.3f\n'
                    '    cont_capacity_loss_infA = %.3f, disc_capacity_loss_infA = %.3f\n'
                    '    cont_capacity_loss_infB = %.3f, disc_capacity_loss_infB = %.3f\n'
                ) % (iteration, epoch,
                     vae_loss.item(), loss_recon.item(), loss_capa.item(),
                     loss_recon_infA.item(), loss_recon_infB.item(), loss_recon_POE.item(),
                     loss_kl_infA.item(), loss_kl_infB.item(),
                     cont_capacity_loss_infA.item(), disc_capacity_loss_infA.item(),
                     cont_capacity_loss_infB.item(), disc_capacity_loss_infB.item())
                print(prn_str)
                if self.record_file:
                    record = open(self.record_file, 'a')
                    record.write('%s\n' % (prn_str,))
                    record.close()

            # save model parameters
            if iteration % self.ckpt_save_iter == 0:
                self.save_checkpoint(iteration)

            # save output images (recon, synth, etc.)
            if iteration % self.output_save_iter == 0:
                # self.save_embedding(iteration, index, muA_infA, muB_infB, muS_infA, muS_infB, muS_POE)

                # 1) save the recon images
                self.save_recon(iteration)

                # self.save_recon2(iteration, index, XA, XB,
                #     torch.sigmoid(XA_infA_recon).data,
                #     torch.sigmoid(XB_infB_recon).data,
                #     torch.sigmoid(XA_POE_recon).data,
                #     torch.sigmoid(XB_POE_recon).data,
                #     muA_infA, muB_infB, muS_infA, muS_infB, muS_POE,
                #     logalpha, logalphaA, logalphaB
                # )
                z_A, z_B, z_S = self.get_stat()

                # 2) save the pure-synthesis images
                # self.save_synth_pure(iteration, howmany=100)

                # 3) save the cross-modal-synthesis images
                # self.save_synth_cross_modal(iteration, z_A, z_B, howmany=3)

                # 4) save the latent traversed images
                self.save_traverseB(iteration, z_A, z_B, z_S)

                # self.get_loglike(logalpha, logalphaA, logalphaB)

                # # 3) save the latent traversed images
                # if self.dataset.lower() == '3dchairs':
                #     self.save_traverse(iteration, limb=-2, limu=2, inter=0.5)
                # else:
                #     self.save_traverse(iteration, limb=-3, limu=3, inter=0.1)

            if iteration % self.eval_metrics_iter == 0:
                # note: z_A and z_B are computed in the output-save branch above
                self.save_synth_cross_modal(iteration, z_A, z_B, train=False, howmany=3)

            # (visdom) insert current line stats
            if self.viz_on and (iteration % self.viz_ll_iter == 0):
                self.line_gather.insert(iter=iteration,
                                        recon_both=loss_recon_POE.item(),
                                        recon_A=loss_recon_infA.item(),
                                        recon_B=loss_recon_infB.item(),
                                        kl_A=loss_kl_infA.item(),
                                        kl_B=loss_kl_infB.item(),
                                        cont_capacity_loss_infA=cont_capacity_loss_infA.item(),
                                        disc_capacity_loss_infA=disc_capacity_loss_infA.item(),
                                        cont_capacity_loss_infB=cont_capacity_loss_infB.item(),
                                        disc_capacity_loss_infB=disc_capacity_loss_infB.item()
                                        )

            # (visdom) visualize line stats (then flush out)
            if self.viz_on and (iteration % self.viz_la_iter == 0):
                self.visualize_line()
                self.line_gather.flush()
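
Note: the `reconstruction_loss` called with `distribution="bernoulli"` above is defined elsewhere in the project. A common formulation consistent with that call, shown here as an assumption rather than the project's exact code, is a summed cross-entropy normalized by batch size:

import torch.nn.functional as F

def reconstruction_loss(x, x_recon, distribution='bernoulli'):
    # hypothetical sketch; x_recon is already sigmoid-activated at the call site
    batch_size = x.size(0)
    if distribution == 'bernoulli':
        return F.binary_cross_entropy(x_recon, x, reduction='sum').div(batch_size)
    # 'gaussian' case: summed squared error per sample
    return F.mse_loss(x_recon, x, reduction='sum').div(batch_size)
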
Example #5
feat = torch.from_numpy(feat.todense().astype(np.float32))

############## init model ##############
gcn_vae = GraphAE(features_dim, hidden_dim, out_dim, bias=False, dropout=0.0)
optimizer_vae = torch.optim.Adam(gcn_vae.parameters(), lr=1e-2)

mlp = MLP(features_dim, hidden_dim, out_dim, dropout=0.0)
optimizer_mlp = torch.optim.Adam(mlp.parameters(), lr=1e-2)

for batch_idx in range(num_iters):
    # train GCN
    optimizer_vae.zero_grad()
    gcn_vae.train()
    z = gcn_vae(adj_norm, feat)
    adj_h = torch.mm(z, z.t())
    vae_train_loss = reconstruction_loss(adj_label, adj_h, norm)
    vae_train_loss.backward()
    optimizer_vae.step()

    # train MLP
    optimizer_mlp.zero_grad()
    mlp.train()
    z_mean, z_log_std = mlp(feat)
    mlp_train_loss = vae_loss(z_mean, z_log_std, adj_label)
    mlp_train_loss.backward()
    optimizer_mlp.step()
    print('GCN [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        batch_idx, num_iters, 100. * batch_idx / num_iters,
        vae_train_loss.item()))

    if batch_idx % 10 == 0:
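
Note: here `reconstruction_loss(adj_label, adj_h, norm)` scores how well the inner-product logits `adj_h` recover the adjacency matrix. A plausible sketch in the spirit of the Kipf & Welling graph autoencoder, with the exact weighting treated as an assumption:

import torch.nn.functional as F

def reconstruction_loss(adj_label, adj_h, norm):
    # hypothetical sketch: norm rescales a binary cross-entropy over all
    # entries of the (dense) adjacency matrix, with adj_h as raw logits
    return norm * F.binary_cross_entropy_with_logits(adj_h, adj_label)
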
Example #6
            # test R-GCN
            gcn_step.eval()
            z_mean_old, z_log_std_old, z_mean_new, z_log_std_new = gcn_step(torch.from_numpy(g_adj), feat, feat_all[feat.size()[0]:, :])
            z_mean = torch.cat((z_mean_old, z_mean_new))
            z_log_std = torch.cat((z_log_std_old, z_log_std_new))

            adj_h = sample_reconstruction(z_mean, z_log_std)
            if refit:
                adj_hat = (adj_h > 0).type(torch.FloatTensor)
                adj_hat[:feat.size(0), :feat.size(0)] = torch.from_numpy(g_adj)
                z_mean, z_log_std = gcn_step(adj_hat, feat_all)
                adj_h = sample_reconstruction(z_mean, z_log_std)


            test_loss = reconstruction_loss(adj_truth_all, adj_h, mask, test=True)
            auc_rgcn = get_roc_auc_score(adj_truth_all, adj_h, mask)
            ap_rgcn = get_average_precision_score(adj_truth_all, adj_h, mask)

            info = 'R-GCN test loss: {:.6f}'.format(test_loss)
            print(info)
            logging.info(info)


            # test original gcn
            gcn_vae.eval()
            adj_vae_norm = torch.eye(feat_all.size()[0])
            adj_vae_norm[:feat.size()[0], :feat.size()[0]] = adj_feed_norm
            z_mean, z_log_std = gcn_vae(adj_vae_norm, feat_all)
            adj_h = sample_reconstruction(z_mean, z_log_std)
            test_loss = reconstruction_loss(adj_truth_all, adj_h, mask, test=True)
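
Note: `sample_reconstruction` turns the Gaussian posterior parameters back into adjacency scores. A minimal sketch, assuming the usual reparameterization trick followed by an inner-product decoder:

import torch

def sample_reconstruction(z_mean, z_log_std):
    # hypothetical sketch: z = mu + eps * sigma, then decode with z @ z^T
    z = z_mean + torch.randn_like(z_mean) * torch.exp(z_log_std)
    return torch.mm(z, z.t())
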
Example #7
    def __init__(self, args):
        # AttributeErrors not handled
        # fail early if the encoder and decoder are not found
        self.encoder = getattr(encoders, args.encoder)
        self.decoder = getattr(decoders, args.decoder)

        # graph definition
        with tf.Graph().as_default() as g:
            # placeholders
            # target frame, number of channels is always one for GIF
            self.T = tf.placeholder(tf.float32,
                                    shape=(args.batch_size, args.crop_height,
                                           args.crop_width, 1))

            self.Z_in = tf.placeholder(tf.float32,
                                       shape=(args.batch_size, args.z_dim))

            # input frame(s), this depends on network parameters
            if args.crop_pos is not None:
                # non FCN case
                if args.window_size > 1:
                    self.X = tf.placeholder(
                        tf.float32,
                        shape=(args.batch_size, args.window_size,
                               args.crop_height, args.crop_width, 1))
                else:
                    self.X = tf.placeholder(tf.float32,
                                            shape=(args.batch_size,
                                                   args.crop_height,
                                                   args.crop_width, 1))
            else:
                # FCN case
                if args.window_size > 1:
                    self.X = tf.placeholder(tf.float32,
                                            shape=(1, args.window_size, None,
                                                   None, 1))
                else:
                    self.X = tf.placeholder(tf.float32,
                                            shape=(1, None, None, 1))
                # TODO: remove this once FCN networks have been added
                raise NotImplementedError

            # feed into networks, with their own unique name_scopes
            if args.encoder == "vae_encoder":
                mu, sigma = self.encoder(self.X, args)
                self.Z = z_sample(mu, sigma)
            else:
                mu, sigma = None, None
                self.Z = self.encoder(self.X, args)

            self.T_hat = self.decoder(self.Z, args, reuse=tf.AUTO_REUSE)
            self.decompression_op = self.decoder(self.Z_in,
                                                 args,
                                                 reuse=tf.AUTO_REUSE)

            # calculate loss
            with tf.name_scope("loss"):
                self.loss_op = reconstruction_loss(self.T_hat, self.T,
                                                   args.loss, mu, sigma,
                                                   args.l1_reg_strength,
                                                   args.l2_reg_strength)

            # optimizer
            with tf.name_scope("optim"):
                self.optimizer = tf.train.AdamOptimizer(
                    learning_rate=args.learning_rate)
                # grads = optimizer.compute_gradients(loss_op)
                self.train_op = self.optimizer.minimize(self.loss_op)

            # summaries
            with tf.name_scope("summary"):
                tf.summary.scalar("sumary_loss", self.loss_op)
                tf.summary.image("sumary_target", self.T)
                tf.summary.image("sumary_recon", self.T_hat)
                self.summary_op = tf.summary.merge_all()

            with tf.name_scope("init"):
                self.init_op = tf.global_variables_initializer()

            self.graph = g
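
Note: the `reconstruction_loss` behind `self.loss_op` takes the loss name plus optional VAE parameters and regularization strengths. One consistent shape for it, with every detail below an assumption rather than the project's actual code:

import tensorflow as tf  # TF 1.x, matching the graph-mode code above

def reconstruction_loss(t_hat, t, loss_name, mu=None, sigma=None,
                        l1_reg_strength=0.0, l2_reg_strength=0.0):
    # hypothetical sketch: pixel loss, plus a KL term when mu/sigma are given;
    # the l1/l2 weight-decay terms are omitted here
    if loss_name == "l1":
        rec = tf.reduce_mean(tf.abs(t_hat - t))
    else:
        rec = tf.reduce_mean(tf.squared_difference(t_hat, t))
    if mu is not None and sigma is not None:
        kl = -0.5 * tf.reduce_mean(
            1.0 + tf.log(tf.square(sigma) + 1e-8) - tf.square(mu) - tf.square(sigma))
        rec = rec + kl
    return rec
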
Example #8
        "b2": Weights("bias_mean_hidden", [1, LATENT_SPACE_DIM]),  # 2
        "b3": Weights("bias_std_hidden", [1, LATENT_SPACE_DIM]),  # 2
        "b4": Weights("bias_matrix_decoder_hidden", [1, NN_DIM]),  # 512
        "b5": Weights("bias_decoder", [1, IMAGE_DIM])  # 784
    }

    for i in tqdm(range(1, epochs)):

        print(f"Epoch {i}")

        for batch, _ in trainloader:

            batch = batch.view(batch.shape[0], -1)
            decoder_output, mean, std = forward_propogate(
                batch, weight_list, bias_list)
            total_loss = ALPHA * reconstruction_loss(
                batch, decoder_output) + BETA * kl_divergence_loss(mean, std)
            total_loss.sum().backward()

            # manual SGD step: w <- w - lr * grad(w), then clear the gradient
            for weight in weight_list:
                w = weight_list[weight]._get_weight()
                w.data = w.data - learning_rate * w.grad.data
                w.grad.data.zero_()

            for bias in bias_list:
                b = bias_list[bias]._get_weight()
                b.data = b.data - learning_rate * b.grad.data
                b.grad.data.zero_()
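
Note: the two loops above implement plain SGD by hand. An equivalent, more idiomatic formulation (a sketch assuming each `_get_weight()` returns a leaf tensor with `requires_grad=True`) delegates the update to `torch.optim.SGD`:

import torch

# collect the raw tensors once, before the training loop
params = [w._get_weight() for w in weight_list.values()]
params += [b._get_weight() for b in bias_list.values()]
optimizer = torch.optim.SGD(params, lr=learning_rate)

# then, inside the batch loop, after total_loss.sum().backward():
optimizer.step()
optimizer.zero_grad()
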
Example #9
def train_glc():
    train_img_paths = get_data(os.path.join(FLAGS.train_file))
    slim.get_or_create_global_step()
    inputs = gen_inputs(FLAGS)
    image = inputs['image_bch']
    mask = inputs['mask_bch']
    gen_output, _ = glc_gen.generator(image, mask, mean_fill=FLAGS.mean_fill)

    ##################
    ## Optimisation ##
    ##################

    # Discriminator loss
    # dis_input = tf.concat([gen_output, image], axis=0)
    # dis_mask = tf.concat([mask]*2, axis=0)
    # dis_labels = tf.concat([tf.zeros(shape=(FLAGS.batch_size,)),
    #                         tf.ones(shape=FLAGS.batch_size,)], axis=0)
    pred_gen_labels, _ = glc_dis.discriminator(gen_output, mask, FLAGS)
    pred_real_labels, _ = glc_dis.discriminator(image, mask, FLAGS, reuse=True)
    # pred_dis_labels, _ = glc_dis.discriminator(dis_input, dis_mask, FLAGS)
    # discriminator_loss = loss.discriminator_minimax_loss(pred_dis_labels, dis_labels)
    # discriminator_loss_library = loss.tf_generator_minmax_disc_loss(tf.slice(pred_dis_labels, [0], [FLAGS.batch_size]),tf.slice(
    # pred_dis_labels, [(FLAGS.batch_size)], [FLAGS.batch_size]))
    # Generator loss
    discriminator_loss_library = gan_loss.modified_discriminator_loss(
        pred_real_labels, pred_gen_labels)
    generator_dis_loss_library = gan_loss.modified_generator_loss(
        pred_gen_labels)
    # gen_dis_input = gen_output
    # gen_dis_masks = mask
    # gen_dis_labels = tf.zeros(shape=(FLAGS.batch_size,))
    # pred_gen_dis_labels, _ = glc_dis.discriminator(gen_dis_input, gen_dis_masks, FLAGS,
    #                                                reuse=True)
    #  generator_dis_loss = loss.generator_minimax_loss(pred_gen_dis_labels, gen_dis_labels)
    # generator_dis_loss_library = loss.tf_generator_minmax_gen_loss(pred_gen_dis_labels)
    generator_rec_loss = loss.reconstruction_loss(gen_output, mask, image)
    generator_tot_loss = tf.add(generator_rec_loss,
                                FLAGS.alpha * generator_dis_loss_library,
                                name='gen_total_loss')
    tf.losses.add_loss(generator_tot_loss)

    dis_optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    gen_rec_optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    gen_dis_optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

    dis_train_op = utils.get_train_op_for_scope(discriminator_loss_library,
                                                dis_optimizer, ['glc_dis'],
                                                FLAGS.clip_gradient_norm)

    generator_rec_train_op = utils.get_train_op_for_scope(
        generator_rec_loss, gen_rec_optimizer, ['glc_gen'],
        FLAGS.clip_gradient_norm)

    generator_dis_train_op = utils.get_train_op_for_scope(
        generator_tot_loss, gen_dis_optimizer, ['glc_gen'],
        FLAGS.clip_gradient_norm)
    layers.summarize_collection(tf.GraphKeys.LOSSES)
    loss_summary_op = tf.summary.merge_all()
    with tf.Session() as sess:
        tb_writer = tf.summary.FileWriter(FLAGS.tb_dir + '/train', sess.graph)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        sess.run(inputs['iterator'].initializer,
                 feed_dict={inputs['image_paths']: train_img_paths})

        #  while True:
        try:
            for counter in range(T_TRAIN):
                step = sess.run(slim.get_global_step())
                if counter < T_C:
                    _, loss_summaries, aa = sess.run([
                        generator_rec_train_op, loss_summary_op,
                        generator_rec_loss
                    ])
                else:
                    _, loss_summaries, aa = sess.run([
                        dis_train_op, loss_summary_op,
                        discriminator_loss_library
                    ])
                if counter > T_C + T_D:
                    _, loss_summaries, aa = sess.run([
                        generator_dis_train_op, loss_summary_op,
                        generator_dis_loss_library
                    ])
                tb_writer.add_summary(loss_summaries, step)
                print('Global_step: {}, Loss: {}'.format(step, aa))
                if step % 100 == 0:
                    saver.save(sess, FLAGS.ckpt_dir)
        except tf.errors.OutOfRangeError:
            pass
    return None
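
Note: `loss.reconstruction_loss(gen_output, mask, image)` comes from the project's own loss module. A sketch of the standard choice for GLC-style inpainting, an L2 penalty restricted to the masked (completed) region, with the exact normalization treated as an assumption:

import tensorflow as tf  # TF 1.x, matching the session-based code above

def reconstruction_loss(gen_output, mask, image):
    # hypothetical sketch: mean squared error over the masked region only
    return tf.reduce_mean(tf.square(mask * (gen_output - image)))
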
Example #10
    def train(self):
        self.net_mode(train=True)

        pbar = trange(self.global_iter, int(self.max_iter))
        epoch = 0
        while True:
            for iter, (x, _) in enumerate(self.data_loader):
                current_iter = epoch * len(self.data_loader) + iter

                x = x.to(device)  # Variable is deprecated; tensors work directly
                x_recon, mu, logvar = self.net(x)
                recon_loss = reconstruction_loss(x, x_recon,
                                                 self.params.distribution)
                total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)

                beta_vae_loss = self.get_loss(recon_loss,
                                              total_kld,
                                              current_iter=current_iter)

                self.optim.zero_grad()
                beta_vae_loss.backward()
                self.optim.step()

                if self.viz_on and current_iter % self.gather_step == 0:
                    self.gather.insert(iter=current_iter,
                                       mu=mu.mean(0).data,
                                       var=logvar.exp().mean(0).data,
                                       recon_loss=recon_loss,
                                       total_kld=total_kld.data,
                                       dim_wise_kld=dim_wise_kld.data,
                                       mean_kld=mean_kld.data)

                if current_iter % self.display_step == 0:
                    pbar.write(
                        '[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'
                        .format(current_iter, recon_loss.item(),
                                total_kld.item(), mean_kld.item()))

                    var = logvar.exp().mean(0).data
                    var_str = ''
                    for j, var_j in enumerate(var):
                        var_str += 'var{}:{:.4f} '.format(j + 1, var_j)
                    pbar.write(var_str)

                    if self.viz_on:
                        self.gather.insert(images=x.data)
                        self.gather.insert(images=torch.sigmoid(x_recon).data)
                        self.viz_reconstruction(current_iter)
                        self.viz_lines(current_iter)
                        self.gather.flush()

                    if self.viz_on or self.save_output:
                        self.viz_traverse(current_iter)

                if current_iter % self.save_step == 0:
                    self.save_checkpoint(self.params, 'last')
                    pbar.write(
                        'Saved checkpoint(iter:{})'.format(current_iter))

                if current_iter % 50000 == 0:
                    self.save_checkpoint(self.params, '%08d' % current_iter)

                pbar.update()

            if pbar.n >= int(self.max_iter):
                break
            epoch += 1

        pbar.write("[Training Finished]")
        pbar.close()
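
Note: `kl_divergence` returns the three KL statistics unpacked in the loop above. A sketch consistent with that triple return, assuming a diagonal Gaussian posterior measured against a standard normal prior:

def kl_divergence(mu, logvar):
    # KL(N(mu, sigma^2) || N(0, 1)) per latent dimension
    klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)   # summed over dims, averaged over batch
    dim_wise_kld = klds.mean(0)             # per-dimension average
    mean_kld = klds.mean(1).mean(0, True)   # averaged over dims and batch
    return total_kld, dim_wise_kld, mean_kld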