def write_model(out_dir, model_name, model):
    filepath_mdl = os.path.join(out_dir, model_name + ".mdl")
    snapshot = {
        "arch": type(model).__name__,
        "z_dim": model.z_dim,
        "input_size": model.input_size,
        "state_dict": model.state_dict(),
    }
    io.makedirs(filepath_mdl)
    torch.save(snapshot, filepath_mdl)
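Note: the snippets on this page rely on a project-local io.makedirs (sometimes imported as a bare makedirs) that, unlike os.makedirs, also accepts a *file* path and creates the missing parent directories. The helper itself is not shown in these examples; below is a minimal sketch of what such a utility might look like, inferred only from how it is called here (the extension-based heuristic is an assumption, not the actual project code):

import os


def makedirs(path):
    """Create 'path' as a directory, or, if 'path' looks like a file
    (has an extension), create its parent directory.

    Hypothetical sketch inferred from usage in the examples.
    """
    dirname = os.path.dirname(path) if os.path.splitext(path)[1] else path
    if dirname and not os.path.isdir(dirname):
        os.makedirs(dirname, exist_ok=True)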
Example #2
    def __init__(self, opts):
        self.opts = opts
        self.device = opts['device']

        # Logdir
        self.logdir = os.path.join(opts['logdir'], opts['exp_name'],
                                   opts['variant_name'])
        io.makedirs(self.logdir)

        # Set seeds
        rn = utils.set_seeds(opts['seed'])

        self.model = MetaSim(opts).to(self.device)
        self.generator = self.model.generator

        tasknet_class = get_tasknet(opts['dataset'])
        self.tasknet = tasknet_class(opts['task']).to(
            self.opts['task']['device'])

        # Data
        sgl = get_scene_graph_loader(opts['dataset'])
        self.scene_graph_dataset = sgl(self.generator,
                                       self.opts['epoch_length'])

        # Rendering layer
        self.renderer = RenderLayer(self.generator, self.device)

        # MMD
        self.mmd = MMDInception(device=self.device,
                                resize_input=self.opts['mmd_resize_input'],
                                include_image=False,
                                dims=self.opts['mmd_dims'])

        dl = get_loader(opts['dataset'])
        self.target_dataset = dl(self.opts['task']['val_root'])
        # In the paper, this is different from the data used to
        # measure the task net accuracy. It is kept the same here
        # for simplicity, to reduce memory overhead. To do this
        # correctly, generate another copy of the target data and
        # use it for the MMD computation.

        # Optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=opts['optim']['lr'],
            weight_decay=opts['optim']['weight_decay'])

        # LR scheduler
        self.lr_sched = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=opts['optim']['lr_decay'],
            gamma=opts['optim']['lr_decay_gamma'])
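The trainer reads everything from a single opts dictionary that the examples never show. The skeleton below is reconstructed purely from the keys accessed in __init__ above; all values are illustrative placeholders, and train() in Example #10 reads further keys such as 'batch_size', 'max_epochs', and 'moving_avg_alpha':

# Hypothetical opts skeleton, inferred from the keys read in __init__.
opts = {
    'device': 'cuda:0',
    'logdir': 'log',
    'exp_name': 'experiment',
    'variant_name': 'baseline',
    'seed': 0,
    'dataset': 'mnist',
    'epoch_length': 1000,
    'mmd_resize_input': True,
    'mmd_dims': 2048,
    'task': {'device': 'cuda:0', 'val_root': 'data/val'},
    'optim': {
        'lr': 1e-3,
        'weight_decay': 0.0,
        'lr_decay': 10,          # StepLR step_size (in epochs)
        'lr_decay_gamma': 0.5,   # StepLR gamma
    },
}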
Example #3
def generate_data(config):
    attr = config['attributes']
    generator_class = get_generator(attr['dataset'])
    generator = generator_class(config)

    # vars and housekeeping
    out_dir = attr['output_dir']
    n_samples = attr['num_samples']

    # out directory
    io.makedirs(out_dir)
    io.write_json(config, os.path.join(out_dir, 'config.json'))

    # generate
    io.generate_data(generator, out_dir, n_samples)
Example #4
    def reconstruct_fixed_samples(self):
        out_dir = os.path.join(cfg.REPORT_DIR, "reconstructions",
                               self.session_name)
        # Reconstruct some fixed images from the training and
        # validation sets (if available).
        for phase, b in self.fixed_batch.items():
            # Downscale visualizations of large images by half.
            f = 1 if b.images.shape[-1] < 512 else 0.5
            img = vis_reconstruction(self.saae,
                                     b.images,
                                     landmarks=b.landmarks,
                                     ncols=5,
                                     fx=f,
                                     fy=f)
            filename = f"reconst_{phase}-{self.session_name}_{self.epoch+1}.jpg"
            img_filepath = os.path.join(out_dir, phase, filename)
            io.makedirs(img_filepath)
            cv2.imwrite(img_filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
Example #5
    def check_paths(path, path_local):
        """Fall back to 'path' if 'path_local' is not defined."""
        if not os.path.exists(path):
            raise IOError(
                f"Could not find dataset {dsname}. Invalid path '{path}'.")
        if not path_local:
            path_local = path
        if not os.path.exists(path_local):
            from utils.io import makedirs

            try:
                makedirs(path_local)
            except OSError:
                print(
                    f"Could not create cache directory for dataset {dsname}.")
                raise
        return path, path_local
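The helper implements a common pattern: validate the primary dataset path, then fall back to it when no local cache path is given (note that check_paths itself reads dsname from its enclosing scope). A self-contained sketch of the same pattern with illustrative names:

import os


def resolve_cache_dir(data_root, cache_root=None):
    """Illustrative re-statement of check_paths' fallback-and-create logic."""
    if not os.path.exists(data_root):
        raise IOError(f"Invalid dataset path '{data_root}'.")
    cache_root = cache_root or data_root
    os.makedirs(cache_root, exist_ok=True)
    return data_root, cache_root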
Example #6
def read_openface_detection(lmFilepath,
                            numpy_lmFilepath=None,
                            from_sequence=False,
                            use_cache=True,
                            return_num_faces=False,
                            expected_face_center=None):
    num_faces_in_image = 0
    try:
        if numpy_lmFilepath is not None:
            npfile = numpy_lmFilepath + '.npz'
        else:
            npfile = lmFilepath + '.npz'
        if os.path.isfile(npfile) and use_cache:
            try:
                data = np.load(npfile)
                of_conf, landmarks, pose = [data[arr] for arr in data.files]
                if of_conf > 0:
                    num_faces_in_image = 1
            except Exception:
                print('Could not open file {}'.format(npfile))
                raise
        else:
            if from_sequence:
                lmFilepath = lmFilepath.replace('features',
                                                'features_sequence')
                lmDir, fname = os.path.split(lmFilepath)
                clip_name = os.path.split(lmDir)[1]
                lmFilepath = os.path.join(lmDir, clip_name)
                features = pd.read_csv(lmFilepath + '.csv',
                                       skipinitialspace=True)
                frame_num = int(os.path.splitext(fname)[0])
                features = features[features.frame == frame_num]
            else:
                features = pd.read_csv(lmFilepath + '.csv',
                                       skipinitialspace=True)
            features.sort_values('confidence', ascending=False, inplace=True)
            selected_face_id = 0
            num_faces_in_image = len(features)
            if num_faces_in_image > 1 and expected_face_center is not None:
                max_face_size = 0
                min_distance = 1000
                for fid in range(len(features)):
                    face = features.iloc[fid]
                    # if face.confidence < 0.2:
                    #     continue
                    landmarks_x = face[[
                        'x_{}'.format(i) for i in range(68)
                    ]].to_numpy(dtype=float)
                    landmarks_y = face[[
                        'y_{}'.format(i) for i in range(68)
                    ]].to_numpy(dtype=float)

                    landmarks = np.vstack((landmarks_x, landmarks_y)).T
                    face_center = landmarks.mean(axis=0)
                    distance = ((face_center -
                                 expected_face_center)**2).sum()**0.5
                    if distance < min_distance:
                        min_distance = distance
                        selected_face_id = fid

            try:
                face = features.iloc[selected_face_id]
            except (KeyError, IndexError):
                face = features
            of_conf = face.confidence
            landmarks_x = face[[
                'x_{}'.format(i) for i in range(68)
            ]].to_numpy(dtype=float)
            landmarks_y = face[[
                'y_{}'.format(i) for i in range(68)
            ]].to_numpy(dtype=float)
            landmarks = np.vstack((landmarks_x, landmarks_y)).T
            pitch = face.pose_Rx
            yaw = face.pose_Ry
            roll = face.pose_Rz
            pose = np.array((pitch, yaw, roll), dtype=np.float32)
            if numpy_lmFilepath is not None:
                makedirs(npfile)
            np.savez(npfile, of_conf, landmarks, pose)
    except IOError:
        # No landmark file found: fall back to an empty detection.
        of_conf = 0
        landmarks = np.zeros((68, 2), dtype=np.float32)
        pose = np.zeros(3, dtype=np.float32)

    result = [of_conf, landmarks.astype(np.float32), pose]
    if return_num_faces:
        result += [num_faces_in_image]
    return result
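A hypothetical call, assuming OpenFace has written 'features/frame0001.csv' and that cached '.npz' files live under a separate directory (both suffixes are appended internally by the function):

# Hypothetical paths for illustration.
of_conf, landmarks, pose, n_faces = read_openface_detection(
    'features/frame0001',
    numpy_lmFilepath='cache/frame0001',
    use_cache=True,
    return_num_faces=True)
if of_conf <= 0:
    print('No face detected.')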
Example #7
def get_face(filename,
             fullsize_img_dir,
             cropped_img_dir,
             landmarks,
             pose=None,
             bb=None,
             size=(cfg.CROP_SIZE, cfg.CROP_SIZE),
             use_cache=True,
             cropper=None):
    filename_noext = os.path.splitext(filename)[0]
    crop_filepath = os.path.join(cropped_img_dir, filename_noext + '.jpg')
    is_cached_crop = False
    if use_cache and os.path.isfile(crop_filepath):
        try:
            img = io.imread(crop_filepath)
        except Exception:
            raise IOError(
                "\tError: Could not load cropped image {}!".format(
                    crop_filepath))
        if img.shape[:2] != size:
            img = cv2.resize(img, size, interpolation=cv2.INTER_CUBIC)
        is_cached_crop = True
    else:
        # Load image from dataset
        img_path = os.path.join(fullsize_img_dir, filename)
        try:
            img = io.imread(img_path)
        except Exception:
            raise IOError("\tError: Could not load image {}!".format(img_path))
        if len(img.shape) == 2 or img.shape[2] == 1:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        assert (img.shape[2] == 3)

    if (landmarks is None
            or not landmarks.any()) and 'crops_celeba' not in cropped_img_dir:
        assert bb is not None
        # Fall back to the bounding box if no landmarks were found.
        crop = face_processing.crop_by_bb(img,
                                          face_processing.scale_bb(bb,
                                                                   f=1.075),
                                          size=size)
    else:
        if 'crops_celeba' in cropped_img_dir:
            if is_cached_crop:
                crop = img
            else:
                crop = face_processing.crop_celeba(img, size)
        else:
            crop, landmarks, pose = face_processing.crop_face(
                img,
                landmarks,
                img_already_cropped=is_cached_crop,
                pose=pose,
                output_size=size,
                crop_by_eye_mouth_dist=cfg.CROP_BY_EYE_MOUTH_DIST,
                align_face_orientation=cfg.CROP_ALIGN_ROTATION,
                crop_square=cfg.CROP_SQUARE)
    if use_cache and not is_cached_crop:
        makedirs(crop_filepath)
        io.imsave(crop_filepath, crop)

    return crop, landmarks, pose
Example #8
def save_bbox(bbox, crop_filepath):
    x1, y1, x2, y2 = bbox
    bbox_filepath = crop_filepath.replace(file_ext_crops, '.bbx')
    makedirs(bbox_filepath)
    # Store the box as [x, y, width, height].
    np.savetxt(bbox_filepath,
               np.array([x1, y1, x2 - x1, y2 - y1], dtype=int),
               fmt='%d')
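Example #9 below reads these '.bbx' files back with np.loadtxt and converts [x, y, w, h] to corner form. A matching loader (the name load_bbox is hypothetical; file_ext_crops is the same module-level constant used above) would look like:

def load_bbox(crop_filepath):
    """Hypothetical counterpart to save_bbox; returns [x1, y1, x2, y2]."""
    bbox_filepath = crop_filepath.replace(file_ext_crops, '.bbx')
    x, y, w, h = np.loadtxt(bbox_filepath)
    return np.array([x, y, x + w, y + h], dtype=int)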
Example #9
    def get_face(self, filename, fullsize_img_dir, cropped_img_root, crop_type='tight',
                 landmarks=None, pose=None, bb=None, size=(cfg.CROP_SIZE, cfg.CROP_SIZE),
                 use_cache=True, detect_face=False, aligned=False, id=None):

        # assert(not detect_face or crop_type == 'loose')

        load_fullsize = False
        loose_bbox = None

        if crop_type == 'fullsize':
            load_fullsize = True
        else:
            crop_dir = crop_type
            if detect_face:
                crop_dir += '_det'
            if not aligned:
                crop_dir += '_noalign'
            filename_noext = os.path.splitext(filename)[0]
            if id is not None:
                filename_noext += '.{:07d}'.format(id)
            cache_filepath = os.path.join(cropped_img_root, crop_dir, filename_noext + file_ext_crops)

            is_cached_crop = False
            if use_cache and os.path.isfile(cache_filepath):
                # Load cached crop; only mark it as cached once the
                # read actually succeeded.
                try:
                    img = io.imread(cache_filepath)
                    is_cached_crop = True
                except Exception:
                    print("\tError: Could not load cropped image {}!".format(cache_filepath))
                    print("\tDeleting file and loading fullsize image.")
                    os.remove(cache_filepath)
                    load_fullsize = True

                if is_cached_crop and crop_type == 'loose':
                    [x, y, w, h] = np.loadtxt(cache_filepath.replace(file_ext_crops, '.bbx'))
                    loose_bbox = np.array([x, y, x + w, y + h], dtype=int)
            else:
                load_fullsize = True

        assert detect_face or landmarks is not None or bb is not None

        if load_fullsize:
            # Load fullsize image from dataset
            # t = time.time()
            img_path = os.path.join(fullsize_img_dir, filename)
            try:
                img = io.imread(img_path)
            except Exception:
                raise IOError("\tError: Could not load image {}!".format(img_path))
            if len(img.shape) == 2 or img.shape[2] == 1:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
            if img.shape[2] == 4:
                print(filename, "converting RGBA to RGB...")
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
            assert img.shape[2] == 3, "{}, invalid format: {}".format(img_path, img.shape)
            # print(time.time()-t)

            # (debug) landmarks can be visualized here, e.g. with
            # vis.show_landmarks(img, landmarks).

            if crop_type == 'fullsize':
                # img.shape is (H, W, C), so bbox and output_size take (W, H).
                fc = face_processing.FaceCrop(img,
                                              bbox=[0, 0, img.shape[1], img.shape[0]],
                                              output_size=(img.shape[1], img.shape[0]))
                return img, landmarks, pose, fc

            if crop_type == 'loose':
                loose_bbox = self.get_loose_crop(img, landmarks=landmarks, detect_face=detect_face)

                fc = face_processing.FaceCrop(img, bbox=loose_bbox, output_size=size)
                loose_crop = fc.apply_to_image(with_hist_norm=False)
                if use_cache:
                    save_bbox(loose_bbox, cache_filepath)
                    io.imsave(cache_filepath, loose_crop)


        if loose_bbox is not None:
            cropper = face_processing.FaceCrop(img, bbox=loose_bbox, output_size=size,
                                               img_already_cropped=is_cached_crop)
        elif landmarks is None or not landmarks.any():
            assert bb is not None
            # Fall back to the bounding box if no landmarks were found.
            cropper = face_processing.FaceCrop(img, bbox=bb,
                                               img_already_cropped=is_cached_crop)
        else:
            cropper = face_processing.FaceCrop(img, landmarks=landmarks, img_already_cropped=is_cached_crop,
                                               align_face_orientation=aligned, output_size=size)

        try:
            crop = cropper.apply_to_image()
        except cv2.error:
            print('Could not crop image {} (load_fullsize=={}).'.format(filename, load_fullsize))
            raise
        landmarks, pose = cropper.apply_to_landmarks(landmarks, pose)

        if use_cache and not is_cached_crop:
            makedirs(cache_filepath)
            io.imsave(cache_filepath, crop)


        # Image pre-processing
        crop = np.minimum(crop, 255)

        # Optional histogram equalization for monochromatic images.
        # (Other filters, e.g. bilateral smoothing or PIL contrast
        # enhancement, were tried and are left disabled.)
        if False:
            if utils.common.is_monochromatic_image(crop):
                crop = exposure.equalize_hist(crop) * 255
                crop = crop.astype(np.uint8)

        crop = cv2.medianBlur(crop, ksize=3)

        return crop, landmarks, pose, cropper
Example #10
    def train(self):
        if self.opts['train_reconstruction']:
            self.train_reconstruction()

        if self.opts['freeze_encoder']:
            self.model.freeze_encoder()

        loader = torch.utils.data.DataLoader(
            self.scene_graph_dataset,
            self.opts['batch_size'],
            num_workers=0,
            collate_fn=self.scene_graph_dataset.collate_fn)

        # baseline for moving average
        baseline = 0.
        alpha = self.opts['moving_avg_alpha']

        for e in range(self.opts['max_epochs']):
            # Set seeds for epoch
            rn = utils.set_seeds(e)

            with torch.no_grad():
                # Generate this epoch's data for task net
                i = 0

                # datadir
                out_dir = os.path.join(self.logdir, 'datagen')
                io.makedirs(out_dir)

                for idx, (g, x, m, adj) in tqdm(enumerate(loader),
                                                desc='Generating Data'):
                    x, adj = x.float().to(self.device), adj.float().to(
                        self.device)
                    # no sampling here

                    dec, dec_act = self.model(x, adj)
                    f = dec_act.cpu().numpy()
                    m = m.cpu().numpy()
                    g = self.generator.update(g, f, m)
                    r = self.generator.render(g)

                    for k in range(len(g)):
                        img, lbl = r[k]
                        out_img = os.path.join(out_dir,
                                               f'{str(i).zfill(6)}.jpg')
                        out_lbl = os.path.join(out_dir,
                                               f'{str(i).zfill(6)}.json')
                        io.write_img(img, out_img)
                        io.write_json(lbl, out_lbl)
                        i += 1

            # task accuracy
            acc = self.tasknet.train_from_dir(out_dir)
            # compute moving average
            if e > 0:
                baseline = alpha * acc + (1 - alpha) * baseline
            else:
                # initialize baseline to acc
                baseline = acc

            # Reset seeds to get exact same outputs
            rn2 = utils.set_seeds(e)
            for i in range(len(rn)):
                assert rn[i] == rn2[i], 'Random numbers generated are different'

            # zero out gradients for first step
            self.optimizer.zero_grad()

            # Train dist matching and task loss
            for idx, (g, x, m, adj) in enumerate(loader):
                x, m, adj = (x.float().to(self.device),
                             m.float().to(self.device),
                             adj.float().to(self.device))

                dec, dec_act, log_probs = self.model(x, adj, m, sample=True)
                # sample here

                # get real images
                im_real = torch.from_numpy(
                    self.target_dataset.get_bunch_images(
                        self.opts['num_real_images'])).to(self.device)

                # get fake images
                im = self.renderer.render(g, dec_act, m)
                # different from generator.render, this
                # has a backward pass implemented and
                # it calls the generator.render function in
                # the forward pass

                if self.opts['dataset'] == 'mnist':
                    # add channel dimension and repeat 3 times for MNIST
                    im = im.unsqueeze(1).repeat(1, 3, 1, 1) / 255.
                    im_real = im_real.permute(0, 3, 1, 2).repeat(1, 3, 1,
                                                                 1) / 255.

                mmd = self.mmd(im_real, im) * self.opts['weight']['dist_mmd']

                if self.opts['use_task_loss']:
                    task_loss = -1 * torch.mean((acc - baseline) * log_probs)
                    loss = mmd + task_loss  # weighting is already done
                    # Gradients accumulate over the whole epoch; the
                    # optimizer steps once after the loop.
                    loss.backward()
                else:
                    mmd.backward()
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                if idx % self.opts['print_freq'] == 0:
                    print(f'[Dist] Step: {idx} MMD: {mmd.item()}')
                    if self.opts['use_task_loss']:
                        print(f'[Task] Reward: {acc}, Baseline: {baseline}')
                    # debug information
                    print(
                        f'[Feat] Step: {idx} {dec_act[0, 2, 15:].tolist()} {x[0, 2, 15:].tolist()}'
                    )
                    # To debug, this index is the loc_x, loc_y, yaw of the
                    # digit in MNIST

            if self.opts['use_task_loss']:
                self.optimizer.step()
                self.optimizer.zero_grad()

            # LR scheduler step
            self.lr_sched.step()
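The reward baseline above is a standard exponential moving average. A tiny self-contained illustration of the update rule used in train() (the accuracy values are made up):

alpha = 0.7  # corresponds to opts['moving_avg_alpha']
accuracies = [0.50, 0.55, 0.62, 0.60]

baseline = accuracies[0]  # initialized to the first epoch's accuracy
for acc in accuracies[1:]:
    baseline = alpha * acc + (1 - alpha) * baseline
    # The task loss then scales log-probs by the advantage (acc - baseline).
    print(f'acc={acc:.2f}  baseline={baseline:.3f}  advantage={acc - baseline:+.3f}')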
Example #11
    def _run_batch(self, batch, eval=False, ds=None):
        time_dataloading = time.time() - self.iter_starttime
        time_proc_start = time.time()
        iter_stats = {"time_dataloading": time_dataloading}

        self.saae.zero_grad()
        self.saae.eval()

        input_images = (batch.target_images
                        if batch.target_images is not None else batch.images)

        with torch.set_grad_enabled(self.args.train_encoder):
            z_sample = self.saae.Q(input_images)

        iter_stats.update({"z_recon_mean": z_sample.mean().item()})

        #######################
        # Reconstruction phase
        #######################
        with torch.set_grad_enabled(self.args.train_encoder and not eval):
            X_recon = self.saae.P(z_sample)

        # calculate reconstruction error for debugging and reporting
        with torch.no_grad():
            iter_stats["loss_recon"] = aae_training.loss_recon(
                batch.images, X_recon)

        #######################
        # Landmark predictions
        #######################
        train_lmhead = not eval
        lm_preds_max = None
        with torch.set_grad_enabled(train_lmhead):
            self.saae.LMH.train(train_lmhead)
            X_lm_hm = self.saae.LMH(self.saae.P)
            if batch.lm_heatmaps is not None:
                loss_lms = F.mse_loss(batch.lm_heatmaps, X_lm_hm) * 100 * 3
                iter_stats.update({"loss_lms": loss_lms.item()})

            if eval or self._is_printout_iter(eval):
                # Expensive, so only calculate every N iterations.
                X_lm_hm = lmutils.smooth_heatmaps(X_lm_hm)
                lm_preds_max = self.saae.heatmaps_to_landmarks(X_lm_hm)

                lm_gt = to_numpy(batch.landmarks)
                nmes = lmutils.calc_landmark_nme(
                    lm_gt,
                    lm_preds_max,
                    ocular_norm=self.args.ocular_norm,
                    image_size=self.args.input_size,
                )
                iter_stats.update({"nmes": nmes})

        if train_lmhead:
            loss_lms.backward()
            self.optimizer_lm_head.step()
            if self.args.train_encoder:
                self.optimizer_E.step()

        # statistics
        iter_stats.update({
            "epoch": self.epoch,
            "timestamp": time.time(),
            "iter_time": time.time() - self.iter_starttime,
            "time_processing": time.time() - time_proc_start,
            "iter": self.iter_in_epoch,
            "total_iter": self.total_iter,
            "batch_size": len(batch),
        })
        self.iter_starttime = time.time()

        self.epoch_stats.append(iter_stats)

        batch_samples = {
            "batch": batch,
            "X_recon": X_recon,
            "X_lm_hm": X_lm_hm,
            "lm_preds_max": lm_preds_max,
            "ds": ds,
        }

        # print stats every N mini-batches
        if self._is_printout_iter(eval):
            self._print_iter_stats(
                self.epoch_stats[-self._print_interval(eval):])

            out_dir = os.path.join(
                cfg.REPORT_DIR,
                "landmark_predictions",
                self.session_name,
                str(self.epoch + 1),
            )
            io.makedirs(out_dir)

            lmvis.visualize_batch(
                batch.images,
                batch.landmarks,
                X_recon,
                X_lm_hm,
                lm_preds_max,
                self.all_landmarks,
                lm_heatmaps=batch.lm_heatmaps,
                target_images=batch.target_images,
                ds=ds,
                ocular_norm=self.args.ocular_norm,
                clean=False,
                overlay_heatmaps_input=False,
                overlay_heatmaps_recon=False,
                f=1.0,
                wait=self.wait,
                skeleton=self.skeleton,
            )
        return batch_samples
Example #12
    def _print_epoch_summary(self,
                             epoch_stats,
                             epoch_starttime,
                             batch_samples,
                             eval=False):
        means = pd.DataFrame(epoch_stats).mean().to_dict()

        try:
            nmes = np.concatenate(
                [s["nmes"] for s in self.epoch_stats if "nmes" in s])
        except (KeyError, ValueError):
            nmes = np.zeros((1, 100))

        duration = int(time.time() - epoch_starttime)
        log.info("{}".format("-" * 100))
        str_stats = (
            "           "
            "l_rec={avg_loss_recon:.3f} "
            "l_lms={avg_loss_lms:.4f} "
            "err_lms={avg_err_lms_all:.2f} "
            "\tT: {time_epoch}")
        log.info(
            str_stats.format(
                iters_per_epoch=self.iters_per_epoch,
                avg_loss=means.get("loss", -1),
                avg_loss_recon=means.get("loss_recon", -1),
                avg_ssim=1.0 - means.get("ssim", -1),
                avg_ssim_torch=means.get("ssim_torch", -1),
                avg_loss_lms=means.get("loss_lms", -1),
                avg_loss_lms_cnn=means.get("loss_lms_cnn", -1),
                avg_err_lms_all=nmes[:, self.all_landmarks].mean(),
                avg_z_recon_mean=means.get("z_recon_mean", -1),
                t=means["iter_time"],
                t_data=means["time_dataloading"],
                t_proc=means["time_processing"],
                total_iter=self.total_iter + 1,
                total_time=str(
                    datetime.timedelta(seconds=self._training_time())),
                time_epoch=str(datetime.timedelta(seconds=duration)),
            ))
        try:
            recon_errors = np.concatenate(
                [stats["l1_recon_errors"] for stats in self.epoch_stats])
            rmse = np.sqrt(np.mean(recon_errors**2))
            log.info("RMSE: {} ".format(rmse))
        except KeyError:
            # print("no l1_recon_error")
            pass

        if self.args.eval and nmes is not None:
            self.print_eval_metrics(nmes, show=False)

        # Saving output images
        batch = batch_samples["batch"]
        X_recon = batch_samples["X_recon"]
        X_lm_hm = batch_samples["X_lm_hm"]
        lm_preds_max = batch_samples["lm_preds_max"]
        ds = batch_samples["ds"]

        out_dir = os.path.join(
            cfg.REPORT_DIR,
            "landmark_predictions",
            self.session_name,
            str(self.epoch + 1),
            "eval" if eval else "train",
        )
        io.makedirs(out_dir)
        lmvis.visualize_batch(
            batch.images,
            batch.landmarks,
            X_recon,
            X_lm_hm,
            lm_preds_max,
            self.all_landmarks,
            lm_heatmaps=batch.lm_heatmaps,
            target_images=batch.target_images,
            ds=ds,
            ocular_norm=self.args.ocular_norm,
            clean=False,
            overlay_heatmaps_input=False,
            overlay_heatmaps_recon=False,
            f=1.0,
            wait=self.wait,
            skeleton=self.skeleton,
            out_dir=out_dir,
        )

        # Saving loss
        out_dir = os.path.join(cfg.REPORT_DIR, "losses", self.session_name)
        filename = "loss_eval.json" if eval else "loss_train.json"
        try:
            with open(os.path.join(out_dir, filename), "r") as infile:
                data = json.load(infile)
        except (json.decoder.JSONDecodeError, FileNotFoundError):
            data = dict()
        losses = {}
        losses.update(
            avg_loss=means.get("loss", -1),
            avg_loss_recon=means.get("loss_recon", -1),
            avg_ssim=1.0 - means.get("ssim", -1),
            avg_ssim_torch=means.get("ssim_torch", -1),
            avg_loss_lms=means.get("loss_lms", -1),
            avg_loss_lms_cnn=means.get("loss_lms_cnn", -1),
            avg_err_lms_all=nmes[:, self.all_landmarks].mean(),
            avg_z_recon_mean=means.get("z_recon_mean", -1),
        )
        data[str(self.epoch + 1)] = losses
        io.makedirs(out_dir)
        with open(os.path.join(out_dir, filename), "w") as outfile:
            json.dump(data, outfile)
Example #13
File: BOA.py  Project: zqcui/MUBench
    def clone(self):
        io.makedirs(self.path)
        clone_command = "git clone --depth 1 --quiet -c core.askpass=true {} .".format(self.url)
        Shell.exec(clone_command, cwd=self.path, logger=self._logger)
Example #14
def export_pytorch_checkpoint_to_tf(model,
                                    ckpt_dir,
                                    bert_output_prefix="bert",
                                    appended_val_map=(),
                                    appended_tensors_to_transpose=()):
    """ Export PyTorch BERT model to TF Checkpoint

        Args:
            model (`nn.Module`) : The PyTorch model you want to save
            ckpt_dir (`str) : The directory of exporting checkpoint
            bert_output_prefix (`str`) : The prefix of BERT module, e.g. bert_pre_trained_model for EasyTransfer
            appended_val_map (`tuple`): A tuple of tuples, ( (PyTorch_var_name, Tensorflow_var_name), ...) )
            appended_tensors_to_transpose (`tuple`): A tuple of PyTorch tensor names you need to transpose
    """
    try:
        import os
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        import re
        import numpy as np
        import tensorflow.compat.v1 as tf
        tf.disable_v2_behavior()
    except ImportError:
        logger.info(
            "Exporting a PyTorch model to TensorFlow requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise RuntimeError

    def to_tf_var_name(name):
        for patt, repl in var_map:
            name = name.replace(patt, repl)
        return name

    def create_tf_var(tensor, name, session):
        tf_dtype = tf.dtypes.as_dtype(tensor.dtype)
        tf_var = tf.get_variable(dtype=tf_dtype,
                                 shape=tensor.shape,
                                 name=name,
                                 initializer=tf.zeros_initializer())
        session.run(tf.variables_initializer([tf_var]))
        session.run(tf_var)
        return tf_var

    var_map = appended_val_map + (
        ("layer.", "layer_"), ("word_embeddings.weight", "word_embeddings"),
        ("position_embeddings.weight", "position_embeddings"),
        ("token_type_embeddings.weight", "token_type_embeddings"), (".", "/"),
        ("LayerNorm/weight", "LayerNorm/gamma"),
        ("LayerNorm/bias", "LayerNorm/beta"), ("/weight", "/kernel"))

    tensors_to_transpose = (
        "dense.weight", "attention.self.query", "attention.self.key",
        "attention.self.value") + appended_tensors_to_transpose

    if not os.path.isdir(ckpt_dir):
        io.makedirs(ckpt_dir)

    state_dict = model.state_dict()

    have_cls_predictions = False
    have_cls_seq_relationship = False
    for key in state_dict.keys():
        if key.startswith("cls.predictions"):
            have_cls_predictions = True
        if key.startswith("cls.seq_relationship"):
            have_cls_seq_relationship = True
    if not have_cls_predictions:
        state_dict["cls.predictions.output_bias"] = torch.zeros(
            model.config.vocab_size)
        state_dict["cls.predictions.transform.LayerNorm.beta"] = torch.zeros(
            model.config.hidden_size)
        state_dict["cls.predictions.transform.LayerNorm.gamma"] = torch.zeros(
            model.config.hidden_size)
        state_dict["cls.predictions.transform.dense.bias"] = torch.zeros(
            model.config.hidden_size)
        state_dict["cls.predictions.transform.dense.kernel"] = torch.zeros(
            (model.config.hidden_size, model.config.hidden_size))
    if not have_cls_seq_relationship:
        state_dict["cls.seq_relationship.output_weights"] = torch.zeros(
            (2, model.config.hidden_size))
        state_dict["cls.seq_relationship.output_bias"] = torch.zeros(2)

    tf.reset_default_graph()
    with tf.Session() as session:
        for var_name in state_dict:
            tf_name = to_tf_var_name(var_name)
            torch_tensor = state_dict[var_name].cpu().numpy()
            if var_name.startswith("bert.") or var_name.startswith("cls."):
                prefix = bert_output_prefix + "/" if bert_output_prefix else ""
            else:
                prefix = ""
            tf_name = prefix + tf_name
            if any(x in var_name for x in tensors_to_transpose):
                torch_tensor = torch_tensor.T
            tf_var = create_tf_var(tensor=torch_tensor,
                                   name=tf_name,
                                   session=session)
            tf.keras.backend.set_value(tf_var, torch_tensor)
            # tf_weight = session.run(tf_var)
            # print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor)))
        create_tf_var(tensor=np.array(1), name="global_step", session=session)
        saver = tf.train.Saver(tf.trainable_variables())

        if "oss://" in ckpt_dir:
            saver.save(session, "model.ckpt")

            for fname in io.listdir("./"):
                if fname.startswith("model.ckpt"):
                    local_file = fname
                    oss_file = os.path.join(ckpt_dir, fname)
                    logger.info("uploading %s" % oss_file)
                    io.upload(local_file, oss_file)
        else:
            saver.save(session, os.path.join(ckpt_dir, "model.ckpt"))
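A hypothetical invocation. The function expects parameter names with 'bert.'/'cls.' prefixes and a config exposing vocab_size and hidden_size, so a HuggingFace-style BertForPreTraining is used here purely as an illustration:

# Hypothetical usage; any BERT-style model with 'bert.'/'cls.' parameter
# names and a .config attribute would do.
from transformers import BertForPreTraining

model = BertForPreTraining.from_pretrained('bert-base-uncased')
export_pytorch_checkpoint_to_tf(model,
                                ckpt_dir='./tf_ckpt',
                                bert_output_prefix='bert')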
Example #15
    def _print_epoch_summary(self, epoch_stats, eval=False):
        means = pd.DataFrame(epoch_stats).mean().to_dict()
        try:
            ssim_scores = np.concatenate([
                stats["ssim"] for stats in self.epoch_stats if "ssim" in stats
            ])
        except (KeyError, ValueError):
            ssim_scores = np.array(0)
        duration = int(time.time() - self.epoch_starttime)

        log.info("{}".format("-" * 140))
        str_stats = (
            "Train:         "
            "l={avg_loss:.3f} "
            "l_rec={avg_loss_recon:.3f} "
            "l_ssim={avg_ssim_torch:.3f}({avg_ssim:.3f}) "
            "l_lmrec={avg_lms_recon:.3f} "
            "l_lmssim={avg_lms_ssim:.3f} "
            "z_mu={avg_z_recon_mean:.3f} "
            "l_D_z={avg_loss_D_z:.4f} "
            "l_E={avg_loss_E:.4f}  "
            "l_D={avg_loss_D:.4f} "
            "l_G={avg_loss_G:.4f} "
            "\tT: {epoch_time} ({total_time})")
        log.info(str_stats.format(
            iters_per_epoch=self.iters_per_epoch,
            avg_loss=means.get("loss", -1),
            avg_loss_recon=means.get("loss_recon", -1),
            avg_lms_recon=means.get("landmark_recon_errors", -1),
            avg_lms_ssim=means.get("landmark_ssim_scores", -1),
            avg_lms_ncc=means.get("landmark_ncc_errors", -1),
            avg_lms_cs=means.get("landmark_cs_errors", -1),
            avg_ssim=ssim_scores.mean(),
            avg_ssim_torch=means.get("ssim_torch", -1),
            avg_loss_E=means.get("loss_E", -1),
            avg_loss_D_z=means.get("loss_D_z", -1),
            avg_loss_D=means.get("loss_D", -1),
            avg_loss_G=means.get("loss_G", -1),
            avg_loss_D_real=means.get("err_real", -1),
            avg_loss_D_fake=means.get("err_fake", -1),
            avg_z_recon_mean=means.get("z_recon_mean", -1),
            t=means["iter_time"],
            t_data=means["time_dataloading"],
            t_proc=means["time_processing"],
            total_iter=self.total_iter + 1,
            total_time=str(datetime.timedelta(seconds=self._training_time())),
            epoch_time=str(datetime.timedelta(seconds=duration)),
        ))

        out_dir = os.path.join(cfg.REPORT_DIR, "losses", self.session_name)
        filename = "loss_eval.json" if eval else "loss_train.json"
        try:
            with open(os.path.join(out_dir, filename), "r") as infile:
                data = json.load(infile)
        except (json.decoder.JSONDecodeError, FileNotFoundError):
            data = dict()
        losses = {}
        losses.update(
            avg_loss=means.get("loss", -1),
            avg_loss_recon=means.get("loss_recon", -1),
            avg_lms_recon=means.get("landmark_recon_errors", -1),
            avg_lms_ssim=means.get("landmark_ssim_scores", -1),
            avg_lms_ncc=means.get("landmark_ncc_errors", -1),
            avg_lms_cs=means.get("landmark_cs_errors", -1),
            avg_ssim=ssim_scores.mean(),
            avg_ssim_torch=means.get("ssim_torch", -1),
            avg_loss_E=means.get("loss_E", -1),
            avg_loss_D_z=means.get("loss_D_z", -1),
            avg_loss_D=means.get("loss_D", -1),
            avg_loss_G=means.get("loss_G", -1),
            avg_loss_D_real=means.get("err_real", -1),
            avg_loss_D_fake=means.get("err_fake", -1),
            avg_z_recon_mean=means.get("z_recon_mean", -1),
        )
        data[str(self.epoch + 1)] = losses
        io.makedirs(out_dir)
        with open(os.path.join(out_dir, filename), "w") as outfile:
            json.dump(data, outfile)