Example #1
0
    def save_progress(self):
        """Render a progress image and dump embedding/logger artifacts.

        Puts ``self.enc``/``self.dec`` in eval mode and builds a tiled image:
        - train split: one image per distinct class plus its reconstruction;
        - test split: one image per distinct class, its reconstruction, and a
          decode of a latent built from class codes plus N(0, 1) draws.

        The image is written to ``{save_dir}/progress_{epoch - 1}.png``.
        Training mode is restored right after the image is written, also when
        an exception is raised. Afterwards ``self.zAll`` (concatenated) and
        ``self.logger`` are pickled and the history/embedding plots are drawn.
        """
        gpu_id = self.gpu_ids[0]
        epoch = self.get_current_epoch()

        data_provider = self.data_provider
        enc = self.enc
        dec = self.dec

        enc.train(False)
        dec.train(False)
        try:
            ###############
            # TRAINING DATA
            ###############
            # np.unique(..., return_index=True) gives the first index of each
            # distinct class label, i.e. one example per class.
            train_classes = data_provider.get_classes(
                np.arange(0, data_provider.get_n_dat('train', override=True)),
                'train')
            _, train_inds = np.unique(train_classes.numpy(), return_index=True)

            x = data_provider.get_images(train_inds, 'train').cuda(gpu_id)

            with torch.no_grad():
                xHat = dec(enc(x))

            imgX = tensor2img(x.data.cpu())
            imgXHat = tensor2img(xHat.data.cpu())
            imgTrainOut = np.concatenate((imgX, imgXHat), 0)

            ###############
            # TESTING DATA
            ###############
            test_classes = data_provider.get_classes(
                np.arange(0, data_provider.get_n_dat('test')), 'test')
            _, test_inds = np.unique(test_classes.numpy(), return_index=True)

            x = data_provider.get_images(test_inds, 'test').cuda(gpu_id)
            with torch.no_grad():
                xHat = dec(enc(x))

            # Latent list fed to the decoder: optional class part, optional
            # ref part, and always a loc part.
            z = list()
            if enc.n_classes > 0:
                class_var = torch.Tensor(
                    data_provider.get_classes(test_inds, 'test',
                                              'one_hot').float()).cuda(gpu_id)
                # Maps one-hot entries 1 -> 0 and 0 -> -25.
                class_var = (class_var - 1) * 25
                z.append(class_var)

            # NOTE(review): class_var has len(test_inds) rows while the random
            # parts below have get_n_classes() rows; they only agree when the
            # test split contains every class — confirm.
            if enc.n_ref > 0:
                ref_var = torch.Tensor(data_provider.get_n_classes(),
                                       enc.n_ref).normal_(0, 1).cuda(gpu_id)
                z.append(ref_var)

            loc_var = torch.Tensor(data_provider.get_n_classes(),
                                   enc.n_latent_dim).normal_(0, 1).cuda(gpu_id)
            z.append(loc_var)

            with torch.no_grad():
                x_z = dec(z)

            imgX = tensor2img(x.data.cpu())
            imgXHat = tensor2img(xHat.data.cpu())
            imgX_z = tensor2img(x_z.data.cpu())
            imgTestOut = np.concatenate((imgX, imgXHat, imgX_z), 0)

            imgOut = np.concatenate((imgTrainOut, imgTestOut))

            # NOTE: scipy.misc.imsave was removed in SciPy 1.2; this call
            # requires an older SciPy.
            scipy.misc.imsave(
                '{0}/progress_{1}.png'.format(self.save_dir, int(epoch - 1)),
                imgOut)
        finally:
            # Never leave the models stuck in eval mode.
            enc.train(True)
            dec.train(True)

        embedding = torch.cat(self.zAll, 0).cpu().numpy()

        # Context managers so the pickle files are flushed and closed.
        with open('{0}/embedding_tmp.pkl'.format(self.save_dir), 'wb') as f:
            pickle.dump(embedding, f)
        with open('{0}/logger_tmp.pkl'.format(self.save_dir), 'wb') as f:
            pickle.dump(self.logger, f)

        # History
        plots.history(self.logger, '{0}/history.png'.format(self.save_dir))

        # Short History
        plots.short_history(self.logger,
                            '{0}/history_short.png'.format(self.save_dir))

        # Embedding figure
        plots.embeddings(embedding, '{0}/embedding.png'.format(self.save_dir))
Example #2
0
    def save_progress(self):
        """Save a progress image, embeddings, logger and plots.

        Switches ``self.enc``/``self.dec`` to eval mode and builds a tiled image:
        - train split: the first ``self.n_display_imgs`` samples and their
          reconstructions;
        - "validate" split: the same samples, their reconstructions, and
          decodes of the latent means after they are redrawn from N(0, 1).

        The image is written to ``{save_dir}/progress_{epoch - 1}.png``.
        Then:
        - ``self.zAll`` (concatenated) is pickled to ``embedding.pth`` and
          ``embedding_{iter}.pth``;
        - ``self.logger`` is pickled and the history/embedding plots are drawn;
        - validate-set latent embeddings are computed and ``torch.save``d.

        Training mode is restored on exit, including when an exception is
        raised.
        """
        gpu_id = self.gpu_ids[0]
        epoch = self.get_current_epoch()

        data_provider = self.data_provider
        enc = self.enc
        dec = self.dec

        def xHat2sample(xHat, x):
            # Output with the same channel count as ``x`` is used directly.
            # Otherwise even channels are the mean and odd channels are
            # treated as a (positive) variance, and a noisy sample is drawn.
            if xHat.shape[1] == x.shape[1]:
                return xHat
            mu = xHat[:, 0::2, :, :]
            log_var = torch.log(xHat[:, 1::2, :, :])
            return self.reparameterize(mu, log_var, add_noise=True)

        enc.train(False)
        dec.train(False)
        try:
            ###############
            # TRAINING DATA
            ###############
            img_inds = np.arange(self.n_display_imgs)

            x, _, _ = data_provider.get_sample("train", img_inds)
            x = x.cuda(gpu_id)

            with torch.no_grad():
                z_mu, _ = enc(x)
                xHat = dec(z_mu)
                xHat = xHat2sample(xHat, x)

            imgX = tensor2im(x.data.cpu())
            imgXHat = tensor2im(xHat.data.cpu())
            imgTrainOut = np.concatenate((imgX, imgXHat), 0)

            ###############
            # TESTING DATA
            ###############
            x, _, _ = data_provider.get_sample("validate", img_inds)
            x = x.cuda(gpu_id)

            with torch.no_grad():
                z_mu, _ = enc(x)
                xHat = dec(z_mu)
                xHat = xHat2sample(xHat, x)

            # Replace the latent means in place with N(0, 1) draws.
            z_mu.normal_()

            with torch.no_grad():
                xHat_z = dec(z_mu)
                xHat_z = xHat2sample(xHat_z, x)

            imgX = tensor2im(x.data.cpu())
            imgXHat = tensor2im(xHat.data.cpu())
            imgXHat_z = tensor2im(xHat_z.data.cpu())
            imgTestOut = np.concatenate((imgX, imgXHat, imgXHat_z), 0)

            imgOut = np.concatenate((imgTrainOut, imgTestOut))

            # NOTE: scipy.misc.imsave was removed in SciPy 1.2.
            scipy.misc.imsave(
                "{0}/progress_{1}.png".format(self.save_dir, int(epoch - 1)),
                imgOut)

            embeddings_train = np.concatenate(self.zAll, 0)

            # Context managers so every pickle file is flushed and closed.
            with open("{0}/embedding.pth".format(self.save_dir), "wb") as f:
                pickle.dump(embeddings_train, f)
            with open(
                    "{0}/embedding_{1}.pth".format(self.save_dir,
                                                   self.get_current_iter()),
                    "wb",
            ) as f:
                pickle.dump(embeddings_train, f)

            with open("{0}/logger_tmp.pkl".format(self.save_dir), "wb") as f:
                pickle.dump(self.logger, f)

            # History
            plots.history(self.logger, "{0}/history.png".format(self.save_dir))

            # Short History
            plots.short_history(self.logger,
                                "{0}/history_short.png".format(self.save_dir))

            # Embedding figure
            plots.embeddings(embeddings_train,
                             "{0}/embedding.png".format(self.save_dir))

            def sampler(mode, inds):
                # Only the image tensor of get_sample's result is needed.
                return data_provider.get_sample(mode, inds)[0]

            # Runs while the models are still in eval mode.
            embeddings_validate = embeddings.get_latent_embeddings(
                enc,
                dec,
                dp=self.data_provider,
                recon_loss=self.crit_recon,
                modes=["validate"],
                batch_size=self.data_provider.batch_size,
                sampler=sampler,
            )
            embeddings_validate["iteration"] = self.get_current_iter()
            embeddings_validate["epoch"] = self.get_current_epoch()

            torch.save(
                embeddings_validate,
                "{}/embeddings_validate_{}.pth".format(self.save_dir,
                                                       self.get_current_iter()),
            )
        finally:
            enc.train(True)
            dec.train(True)
    def save_progress(self):
        """Save a progress image, embeddings, logger and plots.

        Switches ``self.enc``/``self.dec`` to eval mode and builds a tiled image.
        For one example per distinct class in the train split it shows the
        input and its reconstruction. For one example per distinct class in
        the test split it shows the input, its reconstruction, and a decode
        after the latent codes are redrawn from N(0, 1). The encoder and
        decoder are conditioned on each split's own one-hot classes and mesh
        tensor.

        The image is written to ``{save_dir}/progress_{epoch - 1}.png``.
        ``self.zAll`` is then concatenated, reduced to its first slice along
        every axis past the second, and pickled. ``self.logger`` is pickled and
        the history/embedding plots are drawn. Training mode is restored on
        exit, including on exceptions.
        """
        gpu_id = self.gpu_ids[0]
        epoch = self.get_current_epoch()

        data_provider = self.data_provider
        enc = self.enc
        dec = self.dec

        def xHat2sample(xHat, x):
            # Same channel count as ``x``: use as-is. Otherwise even channels
            # are the mean and odd channels a (positive) variance to sample.
            if xHat.shape[1] == x.shape[1]:
                return xHat
            mu = xHat[:, 0::2, :, :]
            log_var = torch.log(xHat[:, 1::2, :, :])
            return self.reparameterize(mu, log_var, add_noise=True)

        def load_split(mode, inds):
            # Fetch images, one-hot classes and mesh for ``mode`` on the GPU.
            x, classes, x_mesh = data_provider.get_sample(
                mode, inds, patched=False
            )
            x = x.cuda(gpu_id)
            classes = classes.type_as(x).long()
            classes_onehot = utils.index_to_onehot(
                classes, self.data_provider.get_n_classes()
            )
            return x, classes_onehot, x_mesh.type_as(x)

        def encode_means(x, classes_onehot, x_mesh):
            # Keep only the first element of each encoder output entry.
            return [z_i[0] for z_i in enc(x, classes_onehot, x_mesh)]

        enc.train(False)
        dec.train(False)
        try:
            ###############
            # TRAINING DATA
            ###############
            # First index of each distinct class label.
            train_classes = data_provider.get_classes(
                np.arange(0, data_provider.get_n_dat("train", override=True)), "train"
            )
            _, train_inds = np.unique(train_classes.numpy(), return_index=True)

            x, classes_onehot, x_mesh = load_split("train", train_inds)

            with torch.no_grad():
                z = encode_means(x, classes_onehot, x_mesh)
                xHat = xHat2sample(dec([classes_onehot] + z, x_mesh), x)

            imgX = tensor2img(x.data.cpu())
            imgXHat = tensor2img(xHat.data.cpu())
            imgTrainOut = np.concatenate((imgX, imgXHat), 0)

            ###############
            # TESTING DATA
            ###############
            test_classes = data_provider.get_classes(
                np.arange(0, data_provider.get_n_dat("test")), "test"
            )
            _, test_inds = np.unique(test_classes.numpy(), return_index=True)

            # Bug fix: condition on the TEST labels and mesh. The original reused
            # the train split's classes_onehot/x_mesh here.
            x, classes_onehot, x_mesh = load_split("test", test_inds)

            with torch.no_grad():
                z = encode_means(x, classes_onehot, x_mesh)
                xHat = xHat2sample(dec([classes_onehot] + z, x_mesh), x)

            # Redraw every latent part in place from N(0, 1).
            for z_sub in z:
                z_sub.normal_()

            with torch.no_grad():
                xHat_z = xHat2sample(dec([classes_onehot] + z, x_mesh), x)

            imgX = tensor2img(x.data.cpu())
            imgXHat = tensor2img(xHat.data.cpu())
            imgXHat_z = tensor2img(xHat_z.data.cpu())
            imgTestOut = np.concatenate((imgX, imgXHat, imgXHat_z), 0)

            imgOut = np.concatenate((imgTrainOut, imgTestOut))

            # NOTE: scipy.misc.imsave was removed in SciPy 1.2.
            scipy.misc.imsave(
                "{0}/progress_{1}.png".format(self.save_dir, int(epoch - 1)), imgOut
            )

            embeddings_train = np.concatenate(self.zAll, 0)

            # Convolutional embeddings have extra trailing axes. Rather than
            # saving everything, keep index 0 of each axis past the second
            # and squeeze.
            if embeddings_train.ndim > 2:
                index = (slice(None), slice(None)) + (slice(0, 1),) * (
                    embeddings_train.ndim - 2
                )
                embeddings_train = np.squeeze(embeddings_train[index])

            with open("{0}/embedding.pth".format(self.save_dir), "wb") as f:
                pickle.dump(embeddings_train, f)
            with open(
                "{0}/embedding_{1}.pth".format(self.save_dir, self.get_current_iter()),
                "wb",
            ) as f:
                pickle.dump(embeddings_train, f)

            with open("{0}/logger_tmp.pkl".format(self.save_dir), "wb") as f:
                pickle.dump(self.logger, f)

            # History
            plots.history(self.logger, "{0}/history.png".format(self.save_dir))

            # Short History
            plots.short_history(self.logger, "{0}/history_short.png".format(self.save_dir))

            # Embedding figure
            plots.embeddings(embeddings_train, "{0}/embedding.png".format(self.save_dir))
        finally:
            enc.train(True)
            dec.train(True)
    def save_progress(self):
        """Save a progress image, embeddings, logger and plots.

        Switches ``self.enc``/``self.dec`` to eval mode and builds a tiled image.
        For one example per distinct class in the train split it shows the
        input and its reconstruction. For one example per distinct class in
        the test split it shows the input, its reconstruction, and a decode
        after the latent codes are redrawn from N(0, 1). The models are
        conditioned on each split's own one-hot classes.

        The image is written to ``{save_dir}/progress_{epoch - 1}.png``.
        ``self.zAll`` is concatenated, flattened to 2-D and pickled.
        ``self.logger`` is pickled and the plots are drawn. Validate-set latent
        embeddings are computed and ``torch.save``d. Training mode is restored
        on exit, including on exceptions.
        """
        gpu_id = self.gpu_ids[0]
        epoch = self.get_current_epoch()

        data_provider = self.data_provider
        enc = self.enc
        dec = self.dec

        def xHat2sample(xHat, x):
            # Same channel count as ``x``: use as-is. Otherwise even channels
            # are the mean and odd channels a (positive) variance to sample.
            if xHat.shape[1] == x.shape[1]:
                return xHat
            mu = xHat[:, 0::2, :, :]
            log_var = torch.log(xHat[:, 1::2, :, :])
            return bvae.reparameterize(mu, log_var, add_noise=True)

        def load_split(mode, inds):
            # Fetch images and one-hot classes for ``mode`` on the GPU.
            x, classes, _ref = data_provider.get_sample(mode, inds)
            x = x.cuda(gpu_id)
            classes = classes.type_as(x).long()
            classes_onehot = utils.index_to_onehot(
                classes, self.data_provider.get_n_classes())
            return x, classes_onehot

        def encode_means(x, classes_onehot):
            # Keep only the first element of each encoder output entry.
            return [z_i[0] for z_i in enc(x, classes_onehot)]

        enc.train(False)
        dec.train(False)
        try:
            ###############
            # TRAINING DATA
            ###############
            # First index of each distinct class label.
            train_classes = data_provider.get_classes(
                np.arange(0, data_provider.get_n_dat("train", override=True)),
                "train")
            _, train_inds = np.unique(train_classes.numpy(), return_index=True)

            x, classes_onehot = load_split("train", train_inds)

            with torch.no_grad():
                z = encode_means(x, classes_onehot)
                xHat = xHat2sample(dec([classes_onehot] + z), x)

            imgX = tensor2img(x.data.cpu())
            imgXHat = tensor2img(xHat.data.cpu())
            imgTrainOut = np.concatenate((imgX, imgXHat), 0)

            ###############
            # TESTING DATA
            ###############
            test_classes = data_provider.get_classes(
                np.arange(0, data_provider.get_n_dat("test")), "test")
            _, test_inds = np.unique(test_classes.numpy(), return_index=True)

            # Bug fix: condition on the TEST labels. The original reused the
            # train split's classes_onehot here.
            x, classes_onehot = load_split("test", test_inds)

            with torch.no_grad():
                z = encode_means(x, classes_onehot)
                xHat = xHat2sample(dec([classes_onehot] + z), x)

            # Redraw every latent part in place from N(0, 1).
            for z_sub in z:
                z_sub.normal_()

            with torch.no_grad():
                xHat_z = xHat2sample(dec([classes_onehot] + z), x)

            imgX = tensor2img(x.data.cpu())
            imgXHat = tensor2img(xHat.data.cpu())
            imgXHat_z = tensor2img(xHat_z.data.cpu())
            imgTestOut = np.concatenate((imgX, imgXHat, imgXHat_z), 0)

            imgOut = np.concatenate((imgTrainOut, imgTestOut))

            # NOTE: scipy.misc.imsave was removed in SciPy 1.2.
            scipy.misc.imsave(
                "{0}/progress_{1}.png".format(self.save_dir, int(epoch - 1)),
                imgOut)

            embeddings_train = np.concatenate(self.zAll, 0)
            # Flatten everything past the first axis.
            embeddings_train = embeddings_train.reshape(embeddings_train.shape[0],
                                                        -1)

            with open("{0}/embedding.pth".format(self.save_dir), "wb") as f:
                pickle.dump(embeddings_train, f)
            with open(
                    "{0}/embedding_{1}.pth".format(self.save_dir,
                                                   self.get_current_iter()),
                    "wb",
            ) as f:
                pickle.dump(embeddings_train, f)

            with open("{0}/logger_tmp.pkl".format(self.save_dir), "wb") as f:
                pickle.dump(self.logger, f)

            # History
            plots.history(self.logger, "{0}/history.png".format(self.save_dir))

            # Short History
            plots.short_history(self.logger,
                                "{0}/history_short.png".format(self.save_dir))

            # Embedding figure
            plots.embeddings(embeddings_train,
                             "{0}/embedding.png".format(self.save_dir))

            # Runs while the models are still in eval mode.
            embeddings_validate = embeddings.get_latent_embeddings(
                enc,
                dec,
                dp=self.data_provider,
                recon_loss=self.crit_recon,
                modes=["validate"],
                batch_size=self.data_provider.batch_size,
            )
            embeddings_validate["iteration"] = self.get_current_iter()
            embeddings_validate["epoch"] = self.get_current_epoch()

            torch.save(
                embeddings_validate,
                "{}/embeddings_validate_{}.pth".format(self.save_dir,
                                                       self.get_current_iter()),
            )
        finally:
            enc.train(True)
            dec.train(True)
Example #5
0
    def save_progress(self):
        """Save a reconstruction progress image, embeddings, logger and plots.

        Switches ``self.enc``/``self.dec`` to eval mode and reconstructs the
        first ``self.n_display_imgs`` samples of the "train" and "validate"
        splits. The inputs and reconstructions are tiled and written to
        ``{save_dir}/progress_{epoch - 1}.png``. ``self.zAll`` (concatenated)
        is pickled to ``embedding.pth`` and ``embedding_{iter}.pth``.
        ``self.logger`` is pickled and the history/embedding plots are drawn.
        Training mode is restored on exit, including on exceptions.
        """
        gpu_id = self.gpu_ids[0]
        epoch = self.get_current_epoch()

        data_provider = self.data_provider
        enc = self.enc
        dec = self.dec

        enc.train(False)
        dec.train(False)
        try:
            img_inds = np.arange(self.n_display_imgs)

            def reconstruct(mode):
                # Returns (inputs, reconstructions) image arrays for ``mode``.
                x, _, _ = data_provider.get_sample(mode, img_inds)
                x = x.cuda(gpu_id)
                with torch.no_grad():
                    xHat = dec(enc(x))
                return tensor2im(x.data.cpu()), tensor2im(xHat.data.cpu())

            ###############
            # TRAINING DATA
            ###############
            imgTrainOut = np.concatenate(reconstruct("train"), 0)

            ###############
            # TESTING DATA
            ###############
            imgTestOut = np.concatenate(reconstruct("validate"), 0)

            imgOut = np.concatenate((imgTrainOut, imgTestOut))

            # NOTE: scipy.misc.imsave was removed in SciPy 1.2.
            scipy.misc.imsave(
                "{0}/progress_{1}.png".format(self.save_dir, int(epoch - 1)), imgOut
            )

            embeddings_train = np.concatenate(self.zAll, 0)

            # Context managers so every pickle file is flushed and closed.
            with open("{0}/embedding.pth".format(self.save_dir), "wb") as f:
                pickle.dump(embeddings_train, f)
            with open(
                "{0}/embedding_{1}.pth".format(self.save_dir, self.get_current_iter()),
                "wb",
            ) as f:
                pickle.dump(embeddings_train, f)

            with open("{0}/logger_tmp.pkl".format(self.save_dir), "wb") as f:
                pickle.dump(self.logger, f)

            # History
            plots.history(self.logger, "{0}/history.png".format(self.save_dir))

            # Short History
            plots.short_history(self.logger, "{0}/history_short.png".format(self.save_dir))

            # Embedding figure
            plots.embeddings(embeddings_train, "{0}/embedding.png".format(self.save_dir))
        finally:
            enc.train(True)
            dec.train(True)
Example #6
0
    def save_progress(self):
        """Save a progress image, the logger pickle and plots.

        Switches ``self.enc``/``self.dec`` to eval mode and builds a tiled image.
        For one example per distinct class in the train split it shows the
        input and its reconstruction. For one example per distinct class in
        the test split it shows the input, its reconstruction, and a decode
        after the latent codes are redrawn from N(0, 1). The models are
        conditioned on each split's own one-hot classes.

        The image is written to ``{save_dir}/progress_{epoch - 1}.png``.
        Training mode is restored right after the image is written, also on
        exceptions. Then ``self.logger`` is pickled and the history plots and
        the embedding plot (built from ``self.zAll``) are drawn.
        """
        gpu_id = self.gpu_ids[0]
        epoch = self.get_current_epoch()

        data_provider = self.data_provider
        enc = self.enc
        dec = self.dec

        def onehot_for(mode, inds):
            # Fetch images and one-hot classes for ``mode`` on the GPU.
            x, classes, _ref = data_provider.get_sample(mode, inds)
            x = x.cuda(gpu_id)
            classes = classes.type_as(x).long()
            classes_onehot = utils.index_to_onehot(
                classes, self.data_provider.get_n_classes()
            )
            return x, classes_onehot

        def encode_means(x, classes_onehot):
            # Keep only the first element of each encoder output entry.
            return [z_i[0] for z_i in enc(x, classes_onehot)]

        enc.train(False)
        dec.train(False)
        try:
            ###############
            # TRAINING DATA
            ###############
            # First index of each distinct class label.
            train_classes = data_provider.get_classes(
                np.arange(0, data_provider.get_n_dat("train", override=True)), "train"
            )
            _, train_inds = np.unique(train_classes.numpy(), return_index=True)

            x, classes_onehot = onehot_for("train", train_inds)

            with torch.no_grad():
                z = encode_means(x, classes_onehot)
                xHat = dec([classes_onehot] + z)

            imgX = tensor2img(x.data.cpu())
            imgXHat = tensor2img(xHat.data.cpu())
            imgTrainOut = np.concatenate((imgX, imgXHat), 0)

            ###############
            # TESTING DATA
            ###############
            test_classes = data_provider.get_classes(
                np.arange(0, data_provider.get_n_dat("test")), "test"
            )
            _, test_inds = np.unique(test_classes.numpy(), return_index=True)

            # Bug fix: condition on the TEST labels. The original reused the
            # train split's classes_onehot here.
            _, classes_onehot = onehot_for("test", test_inds)

            # NOTE(review): images are (re)loaded via get_images, as in the
            # original, rather than taken from get_sample — confirm both agree.
            x = data_provider.get_images(test_inds, "test").cuda(gpu_id)
            with torch.no_grad():
                z = encode_means(x, classes_onehot)
                xHat = dec([classes_onehot] + z)

            # Redraw every latent part in place from N(0, 1).
            for z_sub in z:
                z_sub.normal_()

            with torch.no_grad():
                xHat_z = dec([classes_onehot] + z)

            imgX = tensor2img(x.data.cpu())
            imgXHat = tensor2img(xHat.data.cpu())
            imgXHat_z = tensor2img(xHat_z.data.cpu())
            imgTestOut = np.concatenate((imgX, imgXHat, imgXHat_z), 0)

            imgOut = np.concatenate((imgTrainOut, imgTestOut))

            # NOTE: scipy.misc.imsave was removed in SciPy 1.2.
            scipy.misc.imsave(
                "{0}/progress_{1}.png".format(self.save_dir, int(epoch - 1)), imgOut
            )
        finally:
            enc.train(True)
            dec.train(True)

        # Named embeddings_train so it does not shadow an ``embeddings`` module.
        # The array is only plotted here, not pickled.
        embeddings_train = np.concatenate(self.zAll, 0)

        with open("{0}/logger_tmp.pkl".format(self.save_dir), "wb") as f:
            pickle.dump(self.logger, f)

        # History
        plots.history(self.logger, "{0}/history.png".format(self.save_dir))

        # Short History
        plots.short_history(self.logger, "{0}/history_short.png".format(self.save_dir))

        # Embedding figure
        plots.embeddings(embeddings_train, "{0}/embedding.png".format(self.save_dir))
Example #7
0
    def save_progress(self, enc, dec, data_provider, logger, embedding, opt):
        """Save a reconstruction progress image, pickles and plots.

        Args:
            enc: encoder module, called as ``enc(x_in, y_in)``.
            dec: decoder module, called as ``dec(latent, y_out)``.
            data_provider: supplies class labels and dataset sizes; samples
                are fetched through ``self.get_data``.
            logger: training logger. ``max(logger.log['epoch'])`` names the
                output image; the logger is also pickled and plotted.
            embedding: object pickled to ``embedding_tmp.pkl`` and plotted.
            opt: options object; only ``opt.save_dir`` is read.

        ``enc``/``dec`` are in eval mode while the image is rendered and are
        switched back to train mode afterwards, even on exceptions. For one
        example per distinct class in the train and test splits, the image
        shows ``x_out`` next to its reconstruction. It is written to
        ``{save_dir}/progress_{epoch}.png``.
        """
        epoch = max(logger.log['epoch'])

        enc.train(False)
        dec.train(False)
        try:
            # First index of each distinct class label.
            train_classes = data_provider.get_classes(
                np.arange(0, data_provider.get_n_dat('train')), 'train')
            _, train_inds = np.unique(train_classes.numpy(), return_index=True)

            x_in, x_out, y_in, y_out = self.get_data(
                data_provider, train_inds, train_or_test='train')

            with torch.no_grad():
                xHat = dec(enc(x_in, y_in), y_out)

            imgX = tensor2img(x_out.data.cpu())
            imgXHat = tensor2img(xHat.data.cpu())
            imgTrainOut = np.concatenate((imgX, imgXHat), 0)

            test_classes = data_provider.get_classes(
                np.arange(0, data_provider.get_n_dat('test')), 'test')
            _, test_inds = np.unique(test_classes.numpy(), return_index=True)

            x_in, x_out, y_in, y_out = self.get_data(
                data_provider, test_inds, train_or_test='test')

            with torch.no_grad():
                xHat = dec(enc(x_in, y_in), y_out)

            # The test panel holds x_out and reconstructions only. The former
            # random-latent list was built but never decoded, so it was
            # removed (it only cost GPU allocations and RNG draws).
            imgX = tensor2img(x_out.data.cpu())
            imgXHat = tensor2img(xHat.data.cpu())
            imgTestOut = np.concatenate((imgX, imgXHat), 0)

            imgOut = np.concatenate((imgTrainOut, imgTestOut))

            # NOTE: scipy.misc.imsave was removed in SciPy 1.2.
            scipy.misc.imsave(
                '{0}/progress_{1}.png'.format(opt.save_dir, int(epoch)), imgOut)
        finally:
            enc.train(True)
            dec.train(True)

        # Context managers so the pickle files are flushed and closed.
        with open('{0}/embedding_tmp.pkl'.format(opt.save_dir), 'wb') as f:
            pickle.dump(embedding, f)
        with open('{0}/logger_tmp.pkl'.format(opt.save_dir), 'wb') as f:
            pickle.dump(logger, f)

        # History
        plots.history(logger, '{0}/history.png'.format(opt.save_dir))

        # Short History
        plots.short_history(logger, '{0}/history_short.png'.format(opt.save_dir))

        # Embedding figure
        plots.embeddings(embedding, '{0}/embedding.png'.format(opt.save_dir))