Example #1
0
    def debug_visualize(self, tensors, unique_suffix=None):
        """Visualize a tensor (or list of tensors) in TensorBoard.

        Args:
            tensors: Tensor or list of Tensors to be visualized. Each tensor
                is detached, moved to CPU, and converted to a numpy image.
            unique_suffix: Optional string appended to the TensorBoard tag so
                repeated visualizations get distinct tags.
        """
        # Normalize a single tensor into a one-element list.
        if not isinstance(tensors, list):
            tensors = [tensors]

        visuals = []
        for tensor in tensors:
            tensor = tensor.detach().to('cpu')
            tensor = tensor.numpy().copy()
            tensor_np = util.convert_image_from_tensor(tensor)
            visuals.append(tensor_np)

        # BUG FIX: the original assigned `visuals_np` only in the multi-image
        # branch, so visualizing a single tensor raised NameError at the
        # add_image call below. (The PIL images it built were never used.)
        if len(visuals) > 1:
            visuals_np = util.concat_images(visuals)
        else:
            visuals_np = visuals[0]

        title = 'debug'
        tag = f'{title}'
        if unique_suffix is not None:
            tag += '_{}'.format(unique_suffix)

        self.summary_writer.add_image(tag, np.uint8(visuals_np),
                                      self.global_step)
Example #2
0
    def log_status(self,
                   inputs,
                   targets,
                   probs,
                   loss,
                   save_preds=False,
                   force_visualize=False):
        """Log results and status of training.

        Args:
            inputs: Input batch tensor; only used for its batch size.
            targets: Target batch tensor, passed through to `visualize`.
            probs: Model output tensor; visualized and optionally saved.
            loss: Scalar loss tensor for this step.
            save_preds: If True, also save the prediction image to disk.
            force_visualize: If True, visualize regardless of step count.
        """
        batch_size = inputs.size(0)

        self.loss = loss.item()
        # BUG FIX: update the meter with the Python float, not the live
        # tensor — matches the other loggers in this file and avoids keeping
        # autograd graph / device references alive inside the meter.
        self.loss_meter.update(self.loss, batch_size)

        # Periodically write to the log and TensorBoard
        if self.global_step % self.steps_per_print == 0:

            # Write a header for the log entry
            duration = time() - self.train_start_time
            hours, rem = divmod(duration, 3600)
            minutes, seconds = divmod(rem, 60)

            message = f'[z-test][epoch: {self.epoch}, step: {self.global_step}, time: {int(hours):0>2}:{int(minutes):0>2}:{seconds:05.2f}, loss: {self.loss_meter.avg:.3g}]'
            self.pbar.set_description(message)

            # Write all errors as scalars to the graph
            self._log_scalars({'loss': self.loss_meter.avg},
                              print_to_stdout=False)

        # Periodically visualize up to num_visuals training examples from the batch
        if self.global_step % self.steps_per_visual == 0 or force_visualize:
            self.visualize(probs, targets, probs, phase='z_test')

            if save_preds:
                # BUG FIX: `epoch` was an undefined free variable (NameError
                # whenever save_preds was set); use the logger's epoch.
                probs_image_name = f'prediction-{self.epoch}.png'
                z_test_log_dir = os.path.join(self.log_dir, 'z_test')
                # Ensure the output directory exists before saving.
                os.makedirs(z_test_log_dir, exist_ok=True)
                probs_image_path = os.path.join(z_test_log_dir,
                                                probs_image_name)

                probs = probs.detach().to('cpu')
                probs = probs.numpy().copy()
                probs_np = util.convert_image_from_tensor(probs)

                probs_pil = Image.fromarray(probs_np)
                probs_pil.save(probs_image_path)
Example #3
0
    def log_status(self,
                   inputs,
                   targets,
                   probs,
                   masked_probs,
                   masked_loss,
                   probs_eval,
                   masked_probs_eval,
                   obscured_probs_eval,
                   masked_loss_eval,
                   obscured_loss_eval,
                   full_loss_eval,
                   z_target,
                   z_probs,
                   z_loss,
                   save_preds=False,
                   force_visualize=False):
        """Log results and status of training."""
        batch_size = inputs.size(0)

        # Convert the scalar loss tensors to Python floats.
        masked_loss = masked_loss.item()
        masked_loss_eval = masked_loss_eval.item()
        obscured_loss_eval = obscured_loss_eval.item()
        full_loss_eval = full_loss_eval.item()
        z_loss = z_loss.item()

        # Feed each running-average meter its latest value.
        for meter, value in ((self.masked_loss_meter, masked_loss),
                             (self.masked_loss_eval_meter, masked_loss_eval),
                             (self.obscured_loss_eval_meter,
                              obscured_loss_eval),
                             (self.full_loss_eval_meter, full_loss_eval),
                             (self.z_loss_meter, z_loss)):
            meter.update(value, batch_size)

        # Periodically write to the log and TensorBoard.
        if self.global_step % self.steps_per_print == 0:
            # Elapsed wall-clock time since training started, as H:M:S.
            elapsed = time() - self.train_start_time
            hours, rem = divmod(elapsed, 3600)
            minutes, seconds = divmod(rem, 60)

            message = f'[epoch: {self.epoch}, step: {self.global_step}, time: {int(hours):0>2}:{int(minutes):0>2}:{seconds:05.2f}, masked loss (train): {self.masked_loss_meter.avg:.3g}, masked loss (eval): {self.masked_loss_eval_meter.avg:.3g}, obscured_loss: {self.obscured_loss_eval_meter.avg:.3g}, loss: {self.full_loss_eval_meter.avg:.3g}, z-loss: {self.z_loss_meter.avg:.3g}]'
            self.pbar.set_description(message)

            # Emit each running average as its own TensorBoard scalar.
            for tag, avg in (('loss_masked', self.masked_loss_meter.avg),
                             ('loss_masked-eval',
                              self.masked_loss_eval_meter.avg),
                             ('loss_obscured',
                              self.obscured_loss_eval_meter.avg),
                             ('loss_all', self.full_loss_eval_meter.avg),
                             ('z_loss', self.z_loss_meter.avg)):
                self._log_scalars({tag: avg}, print_to_stdout=False)

        # Periodically visualize up to num_visuals examples from the batch.
        if self.global_step % self.steps_per_visual == 0 or force_visualize:
            # Masked/obscured probabilities are no longer image-sized, so
            # they are not visualized directly.
            self.visualize(probs, targets, obscured_probs_eval, phase='train')
            self.visualize(probs_eval,
                           targets,
                           obscured_probs_eval,
                           phase='eval')
            self.visualize(z_probs, z_target, z_probs, phase='z-test')

            if save_preds:
                image_name = f'prediction-{self.global_step}.png'
                image_path = os.path.join(self.log_dir, image_name)

                probs_cpu = probs.detach().to('cpu').numpy().copy()
                probs_np = util.convert_image_from_tensor(probs_cpu)
                Image.fromarray(probs_np).save(image_path)
Example #4
0
def train_inverted_net(args):
    """Train an inversion network that maps generated images G(z) back to z.

    Depending on `args`, the generator is either run on the fly to produce
    (image, noise) training pairs, or precomputed pairings are loaded from
    disk (ProGAN/StyleGAN). A checkpoint is written at the end.

    Args:
        args: Parsed experiment arguments (models, device, optimizer config,
            `num_invert_epochs`, `save_dir`, ...).

    Returns:
        The trained inversion model (moved to CPU for checkpointing).
    """
    # Start by training an external model on samples of G(z) -> z inversion
    model = util.get_invert_model(args)

    model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    print(f'{args.invert_model} num params {count_parameters(model)}')

    generator = util.get_model(args)
    if generator is not None:
        generator = nn.DataParallel(generator, args.gpu_ids)
        generator = generator.to(args.device)
        print(f'{args.model} num params {count_parameters(generator)}')
    else:
        # Load saved pairings (ProGAN/StyleGAN)
        pairing_dir = '/deep/group/gen-eval/model-training/src/GAN_models/stylegan'
        pairing_path = f'{pairing_dir}/otavio_sampled_output/pairing.csv'
        pairings = pd.read_csv(pairing_path)

        num_pairings = len(pairings)
        noise_targets = pairings['noise']
        image_inputs = pairings['image']

    if 'BigGAN' in args.model:
        # Fixed ImageNet class (207: golden retriever) for BigGAN sampling.
        class_vector = one_hot_from_int(207, batch_size=args.batch_size)
        class_vector = torch.from_numpy(class_vector)
        class_vector = class_vector.cuda()

    # TODO: build a proper DataLoader for the pairing case (cannot use GPU
    # generation inside a worker process).

    criterion = torch.nn.MSELoss().to(args.device)
    optimizer = util.get_optimizer(model.parameters(), args)

    for i in range(args.num_invert_epochs):
        if generator is not None:
            # Sample a fresh z and generate its paired image on the fly.
            noise_target = util.get_noise(args)

            image_input = generator.forward(noise_target).float()
            # Map generator output from [-1, 1] to [0, 1].
            image_input = (image_input + 1.) / 2.
        else:
            # TODO: make into loader
            idx = i % num_pairings
            noise_target = np.load(f'{pairing_dir}/{noise_targets[idx]}')
            noise_target = torch.from_numpy(noise_target).float()
            print(f'noise target shape {noise_target.shape}')

            image_input = np.array(
                Image.open(f'{pairing_dir}/{image_inputs[idx]}'))
            image_input = torch.from_numpy(image_input / 255.)
            image_input = image_input.float().unsqueeze(0)
            # NHWC -> NCHW for the model.
            image_input = image_input.permute(0, 3, 1, 2)

        noise_target = noise_target.cuda()
        image_input = image_input.cuda()

        with torch.set_grad_enabled(True):
            probs = model.forward(image_input)

            # BUG FIX: removed a dead `loss = torch.zeros(...)` that was
            # immediately overwritten by the criterion call.
            loss = criterion(probs, noise_target)
            print(f'iter {i}: loss = {loss}')

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Save a prediction artifact every iteration (raise the modulus to
        # save less often).
        if i % 1 == 0:
            corres_image_input = image_input.detach().cpu()
            corres_np = util.convert_image_from_tensor(corres_image_input)

            # Run check - saving image
            if 'BigGAN' in args.model:
                # BUG FIX: `truncation` was an undefined free variable here
                # (NameError). Use the configured value when present.
                # TODO(review): confirm `args` carries a truncation setting;
                # 0.4 is a common BigGAN default.
                truncation = getattr(args, 'truncation', 0.4)
                predicted_image = generator.forward(probs, class_vector,
                                                    truncation).float()
                # NOTE(review): as in the original, the BigGAN prediction is
                # computed but never saved; only the branches below write.
            else:
                if generator is not None:
                    predicted_image = generator.forward(probs).float()

                    predicted_image = predicted_image.detach().cpu()
                    # Map from [-1, 1] back to [0, 1] before conversion.
                    predicted_image = (predicted_image + 1) / 2.
                    predicted_np = util.convert_image_from_tensor(
                        predicted_image)

                    # Drop the batch dimension if still present.
                    if len(predicted_np.shape) == 4:
                        predicted_np = predicted_np[0]
                        corres_np = corres_np[0]
                    visuals = util.concat_images([predicted_np, corres_np])
                    visuals_pil = Image.fromarray(visuals)
                    timestamp = datetime.now().strftime('%b%d_%H%M%S%f')
                    visuals_image_dir = f'predicted_inversion_images/{args.model}'
                    os.makedirs(visuals_image_dir, exist_ok=True)
                    visuals_image_path = f'{visuals_image_dir}/{timestamp}_{i}.png'
                    visuals_pil.save(visuals_image_path)

                    print(f'Saved {visuals_image_path}')
                else:
                    # Save noise vector - do forward separately in tf env
                    probs = probs.detach().cpu().numpy()
                    pred_noise_dir = f'predicted_inversion_noise/{args.model}'
                    os.makedirs(pred_noise_dir, exist_ok=True)

                    pred_noise_path = f'{pred_noise_dir}/{args.model}_noise_{i}.npy'
                    np.save(pred_noise_path, probs)

                    print(f'Saved {pred_noise_path}')

        # Also save the generator/ground-truth image every iteration.
        if i % 1 == 0:
            corres_image_input = image_input.detach().cpu()
            corres_np = util.convert_image_from_tensor(corres_image_input)

            if len(corres_np.shape) == 4:
                corres_np = corres_np[0]

            corres_pil = Image.fromarray(corres_np)
            timestamp = datetime.now().strftime('%b%d_%H%M%S%f')
            corres_image_dir = f'generated_images/{args.model}'
            os.makedirs(corres_image_dir, exist_ok=True)
            corres_image_path = f'{corres_image_dir}/{timestamp}_{i}.png'
            corres_pil.save(corres_image_path)

    # Checkpoint the trained inverter. Note: `.to('cpu')` moves the model
    # in place, so the returned model lives on CPU.
    global_step = args.num_invert_epochs
    ckpt_dict = {
        'ckpt_info': {
            'global_step': global_step
        },
        'model_name': model.module.__class__.__name__,
        'model_args': model.module.args_dict(),
        'model_state': model.to('cpu').state_dict(),
        'optimizer': optimizer.state_dict(),
    }

    ckpt_dir = os.path.join(args.save_dir, f'{args.model}')
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt_path = os.path.join(
        ckpt_dir, f'{args.invert_model}_step_{global_step}.pth.tar')
    torch.save(ckpt_dict, ckpt_path)
    print(f'Saved model to {ckpt_path}')

    # BUG FIX: removed a leftover `pdb.set_trace()` debugging trap that
    # halted every run before returning.
    return model
Example #5
0
    def log_status(self,
                   masked_probs,
                   masked_loss,
                   masked_test_target,
                   full_probs,
                   full_loss,
                   full_test_target,
                   obscured_probs,
                   obscured_loss,
                   obscured_test_target,
                   save_preds=False,
                   force_visualize=False):
        """Log results and status of z test."""
        batch_size = full_probs.size(0)

        # Convert the scalar loss tensors to Python floats.
        masked_loss = masked_loss.item()
        full_loss = full_loss.item()
        obscured_loss = obscured_loss.item()

        # Feed each running-average meter its latest value.
        for meter, value in ((self.masked_loss_meter, masked_loss),
                             (self.full_loss_meter, full_loss),
                             (self.obscured_loss_meter, obscured_loss)):
            meter.update(value, batch_size)

        # Periodically write to the log and TensorBoard.
        if self.global_step % self.steps_per_print == 0:
            # Elapsed wall-clock time since training started, as H:M:S.
            elapsed = time() - self.train_start_time
            hours, rem = divmod(elapsed, 3600)
            minutes, seconds = divmod(rem, 60)

            message = f'[epoch: {self.epoch}, step: {self.global_step}, time: {int(hours):0>2}:{int(minutes):0>2}:{seconds:05.2f}, masked loss: {self.masked_loss_meter.avg:.3g}, full loss: {self.full_loss_meter.avg:.3g}, obscured_loss: {self.obscured_loss_meter.avg:.3g}]'
            self.pbar.set_description(message)
            self.pbar.update(1)

            # Emit each running average as its own TensorBoard scalar.
            for tag, avg in (('loss', self.masked_loss_meter.avg),
                             ('loss_all', self.full_loss_meter.avg),
                             ('loss_obscured',
                              self.obscured_loss_meter.avg)):
                self._log_scalars({tag: avg}, print_to_stdout=False)

        # Periodically visualize up to num_visuals examples from the batch.
        if self.global_step % self.steps_per_visual == 0 or force_visualize:
            self.visualize(full_probs,
                           full_test_target,
                           obscured_probs,
                           phase='test')

            if save_preds:
                image_name = f'z-test-pred-{self.global_step}.png'
                image_path = os.path.join(self.log_dir, image_name)

                probs_cpu = full_probs.detach().to('cpu').numpy().copy()
                probs_np = util.convert_image_from_tensor(probs_cpu)
                Image.fromarray(probs_np).save(image_path)
Example #6
0
    def visualize(self,
                  probs_batch,
                  targets_batch,
                  obscured_probs_batch,
                  phase,
                  unique_suffix=None,
                  make_separate_prediction_img=False):
        """Visualize predictions and targets in TensorBoard.
        Args:
            probs_batch: Probabilities outputted by the model, in minibatch.
            targets_batch: Target labels for the inputs, in minibatch.
            obscured_probs_batch: Probabilities from the obscured pass, in
                minibatch; shown alongside predictions in non-z-test phases.
            phase: One of 'train', 'z-test' (during training), or 'test' (z-test eval alone).
            unique_suffix: A unique suffix to append to every image title. Allows
              for displaying all visualizations separately on TensorBoard.
            make_separate_prediction_img: If True, also save each composite
                image to disk under the log directory.
        Returns:
            Number of examples visualized to TensorBoard.
        """
        # IMPROVED: hoisted out of the per-example loop below, where it was
        # re-executed on every iteration.
        from PIL import ImageChops

        probs_batch = probs_batch.detach().to('cpu')
        probs_batch = probs_batch.numpy().copy()

        targets_batch = targets_batch.detach().to('cpu')
        targets_batch = targets_batch.numpy().copy()

        obscured_probs_batch = obscured_probs_batch.detach().to('cpu')
        obscured_probs_batch = obscured_probs_batch.numpy().copy()

        batch_size = targets_batch.shape[
            0]  # Do not use self.batch_size -- this is local, handling edge cases
        visual_indices = random.sample(range(batch_size),
                                       min(self.num_visuals, batch_size))
        for i in visual_indices:
            probs = probs_batch[i]
            targets = targets_batch[i]
            obscured_probs = obscured_probs_batch[i]

            probs_np = util.convert_image_from_tensor(probs)
            targets_np = util.convert_image_from_tensor(targets)
            obscured_probs_np = util.convert_image_from_tensor(obscured_probs)

            if phase == "z-test":
                visuals = [probs_np, targets_np]

                title = 'target_pred'
                visuals_image_name = f'{title}-{self.global_step}-{i}.png'
                log_dir_z_test = os.path.join(self.log_dir, 'z_test')
                os.makedirs(log_dir_z_test, exist_ok=True)
                visuals_image_path = os.path.join(log_dir_z_test,
                                                  visuals_image_name)
            else:
                # Pixelwise absolute difference between target and prediction.
                targets_pil = Image.fromarray(targets_np)
                probs_pil = Image.fromarray(probs_np)
                abs_diff = ImageChops.difference(targets_pil, probs_pil)
                abs_diff = np.array(abs_diff)

                visuals = [probs_np, targets_np, abs_diff, obscured_probs_np]

                title = 'pred_target_diff_obscured'
                visuals_image_name = f'{title}-{self.global_step}-{i}.png'
                log_dir_mask = os.path.join(self.log_dir, 'mask')
                os.makedirs(log_dir_mask, exist_ok=True)
                visuals_image_path = os.path.join(log_dir_mask,
                                                  visuals_image_name)

            visuals_np = util.concat_images(visuals)
            visuals_pil = Image.fromarray(visuals_np)

            if make_separate_prediction_img:
                visuals_pil.save(visuals_image_path)

            tag = f'{phase}/{title}'
            if unique_suffix is not None:
                tag += '_{}'.format(unique_suffix)

            # If channel dimension is not first, then move to front
            # (SummaryWriter.add_image expects CHW by default).
            if visuals_np.shape[0] != 3 and visuals_np.shape[2] == 3:
                visuals_np = np.transpose(visuals_np, (2, 0, 1))
            self.summary_writer.add_image(tag, np.uint8(visuals_np),
                                          self.global_step)