Example #1
class TensorboardSummary(object):
    def __init__(self, directory):
        self.directory = directory
        self.writer = SummaryWriter(log_dir=os.path.join(self.directory))

    def add_scalar(self, *args):
        self.writer.add_scalar(*args)

    def visualize_video(self, opt, global_step, video, name):

        video_transpose = video.permute(0, 2, 1, 3, 4)  # BxTxCxHxW
        video_reshaped = video_transpose.flatten(0, 1)  # (B*T)xCxHxW

        # image_range = opt.td #+ opt.num_targets
        image_range = video.shape[2]

        grid_image = make_grid(
            video_reshaped[:3 * image_range, :, :, :].clone().cpu().data,
            image_range,
            normalize=True)
        self.writer.add_image(
            'Video/Scale {}/{}_unfold'.format(opt.scale_idx, name), grid_image,
            global_step)
        norm_range(video_transpose)
        self.writer.add_video('Video/Scale {}/{}'.format(opt.scale_idx, name),
                              video_transpose[:3], global_step)

    def visualize_image(self, opt, global_step, images, name):
        grid_image = make_grid(images[:3, :, :, :].clone().cpu().data,
                               3,
                               normalize=True)
        self.writer.add_image('Image/Scale {}/{}'.format(opt.scale_idx, name),
                              grid_image, global_step)
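
A minimal usage sketch for the class above (hedged: the path and names are illustrative, the module is assumed to already import SummaryWriter, make_grid, and norm_range, which are referenced but not shown, and `video` is a (B, C, T, H, W) float tensor, matching how visualize_video permutes and indexes it):

import torch
from types import SimpleNamespace

summary = TensorboardSummary('runs/exp1')
opt = SimpleNamespace(scale_idx=0)               # only scale_idx is read here
video = torch.rand(4, 3, 8, 64, 64)              # (B, C, T, H, W), values in [0, 1]
summary.visualize_video(opt, global_step=100, video=video, name='samples')
summary.visualize_image(opt, global_step=100, images=video[:, :, 0], name='frame0')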
Example #2
class Logger:
    def __init__(self, log_dir, summary_writer=None):
        self.writer = SummaryWriter(log_dir, flush_secs=1, max_queue=20)
        self.step = 0

    def add_graph(self, model=None, input: tuple = None):
        self.writer.add_graph(model, input)
        self.flush()

    def log_scalar(self, scalar, name, step_):
        self.writer.add_scalar('{}'.format(name), scalar, step_)

    def log_scalars(self, scalar_dict, step, phase='Train_'):
        """Will log all scalars in the same plot."""
        self.writer.add_scalars(phase, scalar_dict, step)

    def log_video(self, video_frames, name, step, fps=10):
        assert len(
            video_frames.shape
        ) == 5, "Need [N, T, C, H, W] input tensor for video logging!"
        self.writer.add_video('{}'.format(name), video_frames, step, fps=fps)

    def flush(self):
        self.writer.flush()

    def close(self):
        self.writer.close()
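
A quick usage sketch; log_video asserts a 5-D [N, T, C, H, W] tensor, so the clip below is shaped accordingly (directory and tags are illustrative):

import torch

logger = Logger('runs/demo')
logger.log_scalar(0.5, 'train/loss', step_=1)
logger.log_scalars({'loss': 0.5, 'acc': 0.9}, step=1)
clip = torch.rand(2, 16, 3, 64, 64)              # [N, T, C, H, W], floats in [0, 1]
logger.log_video(clip, 'rollout', step=1, fps=10)
logger.flush()
logger.close()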
Example #3
class Tensorboard:
    def __init__(self, config: ConfigFactory) -> None:
        self.config = config
        self.writer = SummaryWriter(self.config.tensorboard_path())

    def write_scalar(self, title, value, iteration) -> None:
        self.writer.add_scalar(title, value, iteration)

    def write_video(self, title, value, iteration) -> None:
        self.writer.add_video(title,
                              value,
                              global_step=iteration,
                              fps=self.config.fps)

    def write_image(self, title, value, iteration) -> None:
        self.writer.add_image(title,
                              value,
                              global_step=iteration,
                              dataformats='CHW')

    def write_histogram(self, title, value, iteration) -> None:
        self.writer.add_histogram(title, value, iteration)

    def write_embedding(self, all_embeddings, metadata, images) -> None:
        self.writer.add_embedding(all_embeddings,
                                  metadata=metadata,
                                  label_img=images)

    def write_embedding_no_labels(self, all_embeddings, images) -> None:
        self.writer.add_embedding(all_embeddings, label_img=images)
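
A hedged sketch of the embedding helpers: ConfigFactory belongs to the surrounding project, so a stand-in object exposing the two members actually used (tensorboard_path() and fps) is enough to exercise the class:

import torch

class _StubConfig:                               # stand-in for ConfigFactory
    fps = 8
    def tensorboard_path(self):
        return 'runs/embeddings'

tb = Tensorboard(_StubConfig())
embeddings = torch.randn(100, 32)                # 100 vectors of dimension 32
labels = [str(i % 10) for i in range(100)]       # one metadata label per vector
thumbnails = torch.rand(100, 3, 28, 28)          # one thumbnail image per vector
tb.write_embedding(embeddings, labels, thumbnails)
tb.write_embedding_no_labels(embeddings, thumbnails)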
Example #4
class TensorboardWriter:
    def __init__(self, log_dir: str, *args: Any, **kwargs: Any):
        r"""A Wrapper for tensorboard SummaryWriter. It creates a dummy writer
        when log_dir is empty string or None. It also has functionality that
        generates tb video directly from numpy images.

        Args:
            log_dir: Save directory location. Will not write to disk if
            log_dir is an empty string.
            *args: Additional positional args for SummaryWriter
            **kwargs: Additional keyword args for SummaryWriter
        """
        self.writer = None
        if log_dir is not None and len(log_dir) > 0:
            self.writer = SummaryWriter(log_dir, *args, **kwargs)

    def __getattr__(self, item):
        if self.writer:
            return self.writer.__getattribute__(item)
        else:
            return lambda *args, **kwargs: None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.writer:
            self.writer.close()

    def add_video_from_np_images(self,
                                 video_name: str,
                                 step_idx: int,
                                 images: np.ndarray,
                                 fps: int = 10) -> None:
        r"""Write video into tensorboard from images frames.

        Args:
            video_name: name of video string.
            step_idx: int of checkpoint index to be displayed.
            images: list of n frames. Each frame is a np.ndarray of shape (H, W, 3).
            fps: frame per second for output video.

        Returns:
            None.
        """
        if not self.writer:
            return
        # initial shape of np.ndarray list: N * (H, W, 3)
        frame_tensors = [
            torch.from_numpy(np_arr).unsqueeze(0) for np_arr in images
        ]
        video_tensor = torch.cat(tuple(frame_tensors))
        video_tensor = video_tensor.permute(0, 3, 1, 2).unsqueeze(0)
        # final shape of video tensor: (1, n, 3, H, W)
        self.writer.add_video(video_name,
                              video_tensor,
                              fps=fps,
                              global_step=step_idx)
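
A usage sketch: the context-manager form pairs with the dummy-writer behavior, and frames follow the (H, W, 3) layout noted in the comment above (tag and path are illustrative):

import numpy as np

frames = [np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
          for _ in range(16)]                    # N frames, each (H, W, 3)
with TensorboardWriter('runs/eval') as tb:       # pass '' to disable logging
    tb.add_video_from_np_images('episode', step_idx=0, images=frames, fps=10)
    tb.add_scalar('eval/reward', 1.0, 0)         # forwarded via __getattr__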
Example #5
class Logger(object):
    def __init__(self, log_dir, use_tb=False, config='rl'):
        self._log_dir = log_dir
        if use_tb:
            tb_dir = os.path.join(log_dir, 'tb')
            if os.path.exists(tb_dir):
                shutil.rmtree(tb_dir)
            self._sw = SummaryWriter(tb_dir)
        else:
            self._sw = None
        self._train_mg = MetersGroup(os.path.join(log_dir, 'train.log'),
                                     formating=FORMAT_CONFIG[config]['train'])
        self._eval_mg = MetersGroup(os.path.join(log_dir, 'eval.log'),
                                    formating=FORMAT_CONFIG[config]['eval'])

    def _try_sw_log(self, key, value, step):
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def _try_sw_log_video(self, key, frames, step):
        if self._sw is not None:
            frames = torch.from_numpy(np.array(frames))
            frames = frames.unsqueeze(0)
            self._sw.add_video(key, frames, step, fps=30)

    def _try_sw_log_histogram(self, key, histogram, step):
        if self._sw is not None:
            self._sw.add_histogram(key, histogram, step)

    def log(self, key, value, step, n=1):
        assert key.startswith('train') or key.startswith('eval')
        if type(value) == torch.Tensor:
            value = value.item()
        self._try_sw_log(key, value / n, step)
        mg = self._train_mg if key.startswith('train') else self._eval_mg
        mg.log(key, value, n)

    def log_param(self, key, param, step):
        self.log_histogram(key + '_w', param.weight.data, step)
        if hasattr(param.weight, 'grad') and param.weight.grad is not None:
            self.log_histogram(key + '_w_g', param.weight.grad.data, step)
        if hasattr(param, 'bias'):
            self.log_histogram(key + '_b', param.bias.data, step)
            if hasattr(param.bias, 'grad') and param.bias.grad is not None:
                self.log_histogram(key + '_b_g', param.bias.grad.data, step)

    def log_video(self, key, frames, step):
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_video(key, frames, step)

    def log_histogram(self, key, histogram, step):
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_histogram(key, histogram, step)

    def dump(self, step):
        self._train_mg.dump(step, 'train')
        self._eval_mg.dump(step, 'eval')
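
A usage sketch for the Logger above; keys must start with 'train' or 'eval', and FORMAT_CONFIG and MetersGroup come from the surrounding module:

import numpy as np

logger = Logger('./logs', use_tb=True)           # config='rl' selects the format
logger.log('train/episode_reward', 10.0, step=0)
frames = [np.zeros((3, 64, 64), dtype=np.uint8) for _ in range(8)]
logger.log_video('eval/rollout', frames, step=0)  # stacked into (1, T, C, H, W)
logger.dump(step=0)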
Example #6
class TensorBoardOutputFormat(KVWriter):
    def __init__(self, folder: str):
        """
        Dumps key/value pairs into TensorBoard's numeric format.

        :param folder: the folder to write the log to
        """
        assert SummaryWriter is not None, ("tensorboard is not installed, you can use "
                                           "'pip install tensorboard' to do so")
        self.writer = SummaryWriter(log_dir=folder)

    def write(self,
              key_values: Dict[str, Any],
              key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
              step: int = 0) -> None:

        for (key, value), (_, excluded) in zip(sorted(key_values.items()),
                                               sorted(key_excluded.items())):

            if excluded is not None and "tensorboard" in excluded:
                continue

            if isinstance(value, np.ScalarType):
                if isinstance(value, str):
                    # str is considered a np.ScalarType
                    self.writer.add_text(key, value, step)
                else:
                    self.writer.add_scalar(key, value, step)

            if isinstance(value, th.Tensor):
                self.writer.add_histogram(key, value, step)

            if isinstance(value, Video):
                self.writer.add_video(key, value.frames, step, value.fps)

            if isinstance(value, Figure):
                self.writer.add_figure(key,
                                       value.figure,
                                       step,
                                       close=value.close)

            if isinstance(value, Image):
                self.writer.add_image(key,
                                      value.image,
                                      step,
                                      dataformats=value.dataformats)

        # Flush the output to the file
        self.writer.flush()

    def close(self) -> None:
        """
        closes the file
        """
        if self.writer:
            self.writer.close()
            self.writer = None
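
A sketch of the write() contract: the two dicts must share keys (they are zipped after sorting), and a per-key exclusion tuple containing "tensorboard" suppresses that entry. Video, Figure, and Image are the wrapper types from the same logging module:

fmt = TensorBoardOutputFormat('runs/sb3')
fmt.write(
    key_values={'rollout/ep_rew_mean': 100.0, 'notes': 'first run'},
    key_excluded={'rollout/ep_rew_mean': None, 'notes': ('stdout',)},
    step=1,
)  # the float is logged as a scalar, the string as a text entry
fmt.close()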
Example #7
def main(args):

    # write into tensorboard
    log_path = os.path.join('demos', args.dataset + '/log')
    vid_path = os.path.join('demos', args.dataset + '/vids')

    os.makedirs(log_path, exist_ok=True)
    os.makedirs(vid_path, exist_ok=True)
    writer = SummaryWriter(log_path)

    device = torch.device("cuda:0")

    G = Generator(args.dim_z, args.dim_a, args.nclasses, args.ch).to(device)
    G = nn.DataParallel(G)
    G.load_state_dict(torch.load(args.model_path))

    transform = torchvision.transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    dataset = MUG_test(args.data_path, transform=transform)

    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=args.batch_size,
                                             num_workers=args.num_workers,
                                             shuffle=False,
                                             pin_memory=True)

    with torch.no_grad():

        G.eval()

        img = next(iter(dataloader))

        bs = img.size(0)
        nclasses = args.nclasses

        z = torch.randn(bs, args.dim_z).to(device)

        for i in range(nclasses):
            y = torch.zeros(bs, nclasses).to(device)
            y[:, i] = 1.0
            vid_gen = G(img, z, y)

            vid_gen = vid_gen.transpose(2, 1)
            vid_gen = ((vid_gen - vid_gen.min()) /
                       (vid_gen.max() - vid_gen.min())).data

            writer.add_video(tag='vid_cat_%d' % i, vid_tensor=vid_gen)
            writer.flush()

            # save videos
            print('==> saving videos')
            save_videos(vid_path, vid_gen, bs, i)
Example #8
class TensorLogger(object):
    # creating file in given logdir ... defaults to ./runs/
    def __init__(self, _logdir='./runs/'):
        if not os.path.exists(_logdir):
            os.makedirs(_logdir)

        self.writer = SummaryWriter(log_dir=_logdir)

    # adding scalar value to tb file
    def scalar_summary(self, _tag, _value, _step):
        self.writer.add_scalar(_tag, _value, _step)

    # adding image value to tb file
    def image_summary(self, _tag, _image, _step, _format='CHW'):
        """
            default dataformat for image tensor is (3, H, W)
            can be changed to
                : (1, H, W) - dataformat = CHW
                : (H, W, 3) - dataformat HWC
                : (H, W) - datformat HW
        """
        self.writer.add_image(_tag, _image, _step, dataformats=_format)

    # adding matplotlib figure to tb file
    def figure_summary(self, _tag, _figure, _step):
        self.writer.add_figure(_tag, _figure, _step)

    # adding video to tb file
    def video_summary(self, _tag, _video, _step, _fps=4):
        """
            default torch fps is 4, can be changed
            also, video tensor should be of format (N, T, C, H, W)
            values should be between [0,255] for uint8 and [0,1] for float32
        """
        # default value of video fps is 4 - can be changed
        self.writer.add_video(_tag, _video, _step, _fps)

    # adding audio to tb file
    def audio_summary(self, _tag, _sound, _step, _sampleRate=44100):
        """
            default torch sample rate is 44100, can be changed
            also, sound tensor should be of format (1, L)
            values should lie between [-1,1]
        """
        self.writer.add_audio(_tag, _sound, _step, sample_rate=_sampleRate)

    # adding text to tb file
    def text_summary(self, _tag, _textString, _step):
        self.writer.add_text(_tag, _textString, _step)

    # adding histograms to tb file
    def histogram_summary(self, _tag, _histogram, _step, _bins='tensorflow'):
        self.writer.add_histogram(_tag, _histogram, _step, bins=_bins)
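
A usage sketch matching the tensor formats documented above (tags are illustrative):

import torch

tlog = TensorLogger('./runs/')
tlog.scalar_summary('loss', 0.1, 1)
tlog.image_summary('sample', torch.rand(3, 32, 32), 1)      # (3, H, W), CHW
tlog.video_summary('clip', torch.rand(1, 8, 3, 32, 32), 1)  # (N, T, C, H, W) in [0, 1]
tlog.audio_summary('tone', torch.zeros(1, 44100), 1)        # (1, L) in [-1, 1]
tlog.histogram_summary('weights', torch.randn(1000), 1)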
Example #9
class VisualizerTensorboard:
    def __init__(self, opts):
        self.dtype = {}
        self.iteration = 1
        self.writer = SummaryWriter(opts.logs_dir)

    def register(self, modules):
        # here modules are assumed to be a dictionary
        for key in modules:
            self.dtype[key] = modules[key]['dtype']

    def update(self, modules):
        for key, value in modules.items():
            if self.dtype[key] == 'scalar':
                self.writer.add_scalar(key, value, self.iteration)
            elif self.dtype[key] == 'scalars':
                self.writer.add_scalars(key, value, self.iteration)
            elif self.dtype[key] == 'histogram':
                self.writer.add_histogram(key, value, self.iteration)
            elif self.dtype[key] == 'image':
                self.writer.add_image(key, value, self.iteration)
            elif self.dtype[key] == 'images':
                self.writer.add_images(key, value, self.iteration)
            elif self.dtype[key] == 'figure':
                self.writer.add_figure(key, value, self.iteration)
            elif self.dtype[key] == 'video':
                self.writer.add_video(key, value, self.iteration)
            elif self.dtype[key] == 'audio':
                self.writer.add_audio(key, value, self.iteration)
            elif self.dtype[key] == 'text':
                self.writer.add_text(key, value, self.iteration)
            elif self.dtype[key] == 'embedding':
                self.writer.add_embedding(value,
                                          global_step=self.iteration,
                                          tag=key)
            elif self.dtype[key] == 'pr_curve':
                self.writer.add_pr_curve(key, value['labels'],
                                         value['predictions'], self.iteration)
            elif self.dtype[key] == 'mesh':
                self.writer.add_mesh(key, value, global_step=self.iteration)
            elif self.dtype[key] == 'hparams':
                self.writer.add_hparams(value['hparam_dict'],
                                        value['metric_dict'])
            else:
                raise Exception(
                    'Data type not supported, please update the visualizer plugin and rerun !!'
                )

        self.iteration = self.iteration + 1
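
A usage sketch: register() declares a dtype per key, after which each update() call dispatches on it and advances the shared iteration counter (opts only needs a logs_dir attribute):

import torch
from types import SimpleNamespace

vis = VisualizerTensorboard(SimpleNamespace(logs_dir='runs/vis'))
vis.register({'loss': {'dtype': 'scalar'},
              'frames': {'dtype': 'video'}})
vis.update({'loss': 0.25,
            'frames': torch.rand(1, 8, 3, 32, 32)})  # (N, T, C, H, W)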
Example #10
def main():

    args = cfg.parse_args()

    # write into tensorboard
    log_path = os.path.join(args.demo_path, args.demo_name + '/log')
    vid_path = os.path.join(args.demo_path, args.demo_name + '/vids')

    os.makedirs(log_path, exist_ok=True)
    os.makedirs(vid_path, exist_ok=True)
    writer = SummaryWriter(log_path)

    device = torch.device("cuda:0")

    G = Generator().to(device)
    G = nn.DataParallel(G)
    G.load_state_dict(torch.load(args.model_path))

    with torch.no_grad():
        G.eval()

        za = torch.randn(args.n_za_test, args.d_za, 1, 1, 1).to(device)
        zm = torch.randn(args.n_zm_test, args.d_zm, 1, 1, 1).to(device)

        n_za = za.size(0)
        n_zm = zm.size(0)
        za = za.unsqueeze(1).repeat(1, n_zm, 1, 1, 1, 1).contiguous().view(
            n_za * n_zm, -1, 1, 1, 1)
        zm = zm.repeat(n_za, 1, 1, 1, 1)

        vid_fake = G(za, zm)

        vid_fake = vid_fake.transpose(2, 1)  # bs x 16 x 3 x 64 x 64
        vid_fake = ((vid_fake - vid_fake.min()) /
                    (vid_fake.max() - vid_fake.min())).data

        writer.add_video(tag='generated_videos',
                         global_step=1,
                         vid_tensor=vid_fake)
        writer.flush()

        # save into videos
        print('==> saving videos...')
        save_videos(vid_path, vid_fake, n_za, n_zm)

    return
Example #11
class TensorboardSummary(object):
    def __init__(self, directory, neptune_exp=None):
        self.directory = directory
        self.writer = SummaryWriter(log_dir=os.path.join(self.directory))
        self.neptune_exp = neptune_exp
        self.to_image = transforms.ToPILImage()

    def add_scalar(self, log_name, value, index):
        if self.neptune_exp:
            self.neptune_exp.log_metric(log_name, index, value)
        else:
            self.writer.add_scalar(log_name, value, index)

    def visualize_video(self, opt, global_step, video, name):

        video_transpose = video.permute(0, 2, 1, 3, 4)  # BxTxCxHxW
        video_reshaped = video_transpose.flatten(0, 1)  # (B*T)xCxHxW

        # image_range = opt.td #+ opt.num_targets
        image_range = video.shape[2]

        grid_image = make_grid(
            video_reshaped[:3 * image_range, :, :, :].clone().cpu().data,
            image_range,
            normalize=True)
        self.writer.add_image(
            'Video/Scale {}/{}_unfold'.format(opt.scale_idx, name), grid_image,
            global_step)
        norm_range(video_transpose)
        self.writer.add_video('Video/Scale {}/{}'.format(opt.scale_idx, name),
                              video_transpose[:3], global_step)

    def visualize_image(self, opt, global_step, images, name):
        grid_image = make_grid(images[:3, :, :, :].clone().cpu().data,
                               3,
                               normalize=True)
        img_name = 'Image/Scale {}/{}'.format(opt.scale_idx, name)
        if self.neptune_exp:
            self.neptune_exp.log_image(img_name,
                                       global_step,
                                       y=self.to_image(grid_image))
        else:
            self.writer.add_image(img_name, grid_image, global_step)
Example #12
    def train_vis(self, model, dataset, writer: SummaryWriter, global_step,
                  indices, device, cond_steps, fg_sample, bg_sample, num_gen):

        grid, gif = self.show_tracking(model, dataset, indices, device)
        writer.add_image('tracking/grid', grid, global_step)
        for i in range(len(gif)):
            writer.add_video(f'tracking/video_{i}', gif[i:i + 1], global_step)

        # Generation
        grid, gif = self.show_generation(model,
                                         dataset,
                                         indices,
                                         device,
                                         cond_steps,
                                         fg_sample=fg_sample,
                                         bg_sample=bg_sample,
                                         num=num_gen)
        writer.add_image('generation/grid', grid, global_step)
        for i in range(len(gif)):
            writer.add_video(f'generation/video_{i}', gif[i:i + 1],
                             global_step)
Example #13
    def train(self, lens, texts, movies, pp=0, lp=20):
        """
        lp: log period
        pp: print period (0 - no print to stdout)
        """
        writer = SummaryWriter()
        writer.add_video("Real Clips", to_video(movies))
        device = next(self.generator.parameters()).device

        time_per_epoch = -time.time()
        for epoch in range(self.num_epochs):
            for No, batch in enumerate(self.vloader):
                labels = batch['label'].to(device, non_blocking=True)
                videos = batch['video'].to(device, non_blocking=True)
                senlen = batch['slens'].to(device, non_blocking=True)
                self.passBatchThroughNetwork(labels, videos, senlen)

            if pp and epoch % pp == 0:
                time_per_epoch += time.time()
                print(f'Epoch {epoch}/{self.num_epochs}')
                for k, v in self.logs.items():
                    print("\t%s:\t%5.4f" % (k, v / (No + 1)))
                    self.logs[k] = v / (No + 1)
                print('Completed in %.f s' % time_per_epoch)
                time_per_epoch = -time.time()

            if epoch % lp == 0:
                self.generator.eval()
                with torch.no_grad():
                    condition = self.encoder(texts, lens)
                    movies = self.generator(condition)
                writer.add_scalars('Loss', self.logs, epoch)
                writer.add_video('Fakes', to_video(movies), epoch)
                self.generator.train()
            self.logs = dict.fromkeys(self.logs, 0)

        torch.save(self.generator.state_dict(),
                   self.log_folder / ('gen_%05d.pytorch' % epoch))
        print('Training has been completed successfully!')
Example #14
def main():

	args = cfg.parse_args()

	# write into tensorboard
	log_path = os.path.join(args.demo_path, args.demo_name + '/log')
	vid_path = os.path.join(args.demo_path, args.demo_name + '/vids')
	os.makedirs(log_path, exist_ok=True)
	os.makedirs(vid_path, exist_ok=True)
	writer = SummaryWriter(log_path)

	device = torch.device("cuda:0")

	G = Generator().to(device)
	G = nn.DataParallel(G)
	G.load_state_dict(torch.load(args.model_path))

	with torch.no_grad():
		G.eval()

		za = torch.randn(args.n_za_test, args.d_za, 1, 1, 1).to(device) # appearance

		# generating frames from [16, 20, 24, 28, 32, 36, 40, 44, 48]
		for i in range(9):
			zm = torch.randn(args.n_zm_test, args.d_zm, (i+1), 1, 1).to(device) # 16+i*4
			vid_fake = G(za, zm)
			vid_fake = vid_fake.transpose(2,1)
			vid_fake = ((vid_fake - vid_fake.min()) / (vid_fake.max() - vid_fake.min())).data
			writer.add_video(tag='generated_videos_%dframes'%(16+i*4), global_step=1, vid_tensor=vid_fake)
			writer.flush()

			print('saving videos')
			save_videos(vid_path, vid_fake, args.n_za_test, (16+i*4))

	return
Example #15
class Trainer():
    def __init__(
        self,
        data_loaders,
        generator,
        gen_optimizer,
        end_epoch,
        criterion,
        start_epoch=0,
        lr_scheduler=None,
        device=None,
        writer=None,
        debug=False,
        debug_freq=1000,
        logdir='output',
        resume=None,
        performance_type='min',
        num_iters_per_epoch=1000,
    ):

        # Prepare dataloaders
        self.train_2d_loader, self.train_3d_loader, self.valid_loader = data_loaders

        self.train_2d_iter = self.train_3d_iter = None

        if self.train_2d_loader:
            self.train_2d_iter = iter(self.train_2d_loader)

        if self.train_3d_loader:
            self.train_3d_iter = iter(self.train_3d_loader)

        # Models and optimizers
        self.generator = generator
        self.gen_optimizer = gen_optimizer

        # Training parameters
        self.start_epoch = start_epoch
        self.end_epoch = end_epoch
        self.criterion = criterion
        self.lr_scheduler = lr_scheduler
        self.device = device
        self.writer = writer
        self.debug = debug
        self.debug_freq = debug_freq
        self.logdir = logdir

        self.performance_type = performance_type
        self.train_global_step = 0
        self.valid_global_step = 0
        self.epoch = 0
        self.best_performance = float(
            'inf') if performance_type == 'min' else -float('inf')

        self.evaluation_accumulators = dict.fromkeys(
            ['pred_j3d', 'target_j3d', 'target_theta', 'pred_verts'])

        self.num_iters_per_epoch = num_iters_per_epoch

        if self.writer is None:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(log_dir=self.logdir)

        if self.device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Resume from a pretrained model
        if resume is not None:
            self.resume_pretrained(resume)

    def train(self):
        # Single epoch training routine

        losses = AverageMeter()

        timer = {
            'data': 0,
            'forward': 0,
            'loss': 0,
            'backward': 0,
            'batch': 0,
        }

        self.generator.train()

        start = time.time()

        summary_string = ''

        # bar = Bar(f'Epoch {self.epoch + 1}/{self.end_epoch}', fill='#', max=self.num_iters_per_epoch)
        pbar = tqdm(range(self.num_iters_per_epoch))
        for i in pbar:
            # Dirty solution to reset an iterator
            target_2d = target_3d = None
            if self.train_2d_iter:
                try:
                    target_2d = next(self.train_2d_iter)
                except StopIteration:
                    self.train_2d_iter = iter(self.train_2d_loader)
                    target_2d = next(self.train_2d_iter)

                move_dict_to_device(target_2d, self.device)

            if self.train_3d_iter:
                try:
                    target_3d = next(self.train_3d_iter)
                except StopIteration:
                    self.train_3d_iter = iter(self.train_3d_loader)
                    target_3d = next(self.train_3d_iter)

                move_dict_to_device(target_3d, self.device)

            # <======= Feedforward generator and discriminator
            if target_2d and target_3d:
                inp = torch.cat((target_2d['features'], target_3d['features']),
                                dim=0).to(self.device)
            elif target_3d:
                inp = target_3d['features'].to(self.device)
            else:
                inp = target_2d['features'].to(self.device)

            timer['data'] = time.time() - start
            start = time.time()

            preds = self.generator(inp)

            timer['forward'] = time.time() - start
            start = time.time()

            gen_loss, loss_dict = self.criterion(
                generator_outputs=preds,
                data_2d=target_2d,
                data_3d=target_3d,
            )
            # =======>
            timer['loss'] = time.time() - start
            start = time.time()

            # <======= Backprop generator and discriminator
            self.gen_optimizer.zero_grad()
            gen_loss.backward()
            self.gen_optimizer.step()

            # if self.train_global_step % self.dis_motion_update_steps == 0:
            # self.dis_motion_optimizer.zero_grad()
            # motion_dis_loss.backward()
            # self.dis_motion_optimizer.step()
            # =======>

            # <======= Log training info
            # total_loss = gen_loss + motion_dis_loss
            total_loss = gen_loss

            losses.update(total_loss.item(), inp.size(0))

            timer['backward'] = time.time() - start
            timer['batch'] = timer['data'] + timer['forward'] + timer[
                'loss'] + timer['backward']
            start = time.time()

            # summary_string = f'({i + 1}/{self.num_iters_per_epoch}) | Total: {bar.elapsed_td} | ' \
            #                  f'ETA: {bar.eta_td:} | loss: {losses.avg:.4f}'
            summary_string = '| loss: {:.4f}'.format(losses.avg)

            for k, v in loss_dict.items():
                summary_string += f' | {k}: {v:.2f}'
                self.writer.add_scalar('train_loss/' + k,
                                       v,
                                       global_step=self.train_global_step)

            # for k,v in timer.items():
            #     summary_string += ' | {}: {:.2f}'.format(k, v)

            self.writer.add_scalar('train_loss/loss',
                                   total_loss.item(),
                                   global_step=self.train_global_step)

            if self.debug:
                print('==== Visualize ====')
                from psypose.MEVA.meva.utils.vis import batch_visualize_vid_preds
                video = target_3d['video']
                dataset = 'spin'
                vid_tensor = batch_visualize_vid_preds(video,
                                                       preds[-1],
                                                       target_3d.copy(),
                                                       vis_hmr=False,
                                                       dataset=dataset)
                self.writer.add_video('train-video',
                                      vid_tensor,
                                      global_step=self.train_global_step,
                                      fps=10)

            self.train_global_step += 1
            # bar.suffix = summary_string
            # bar.next()
            pbar.set_description(summary_string)

            if torch.isnan(total_loss):
                exit('NaN value in loss, exiting!...')
            # =======>

        # bar.finish()

        logger.info(summary_string)

    def validate(self):
        self.generator.eval()

        start = time.time()

        summary_string = ''

        bar = Bar('Validation', fill='#', max=len(self.valid_loader))

        if self.evaluation_accumulators is not None:
            for k, v in self.evaluation_accumulators.items():
                self.evaluation_accumulators[k] = []

        J_regressor = torch.from_numpy(
            np.load(osp.join(MEVA_DATA_DIR, 'J_regressor_h36m.npy'))).float()

        for i, target in enumerate(self.valid_loader):

            move_dict_to_device(target, self.device)

            # <=============
            with torch.no_grad():
                inp = target['features']

                preds = self.generator(inp, J_regressor=J_regressor)

                # convert to 14 keypoint format for evaluation
                n_kp = preds[-1]['kp_3d'].shape[-2]
                pred_j3d = preds[-1]['kp_3d'].view(-1, n_kp, 3).cpu().numpy()
                target_j3d = target['kp_3d'].view(-1, n_kp, 3).cpu().numpy()
                pred_verts = preds[-1]['verts'].view(-1, 6890, 3).cpu().numpy()
                target_theta = target['theta'].view(-1, 85).cpu().numpy()

                self.evaluation_accumulators['pred_verts'].append(pred_verts)
                self.evaluation_accumulators['target_theta'].append(
                    target_theta)

                self.evaluation_accumulators['pred_j3d'].append(pred_j3d)
                self.evaluation_accumulators['target_j3d'].append(target_j3d)
            # =============>

            # <============= DEBUG
            if self.debug and self.valid_global_step % self.debug_freq == 0:
                from psypose.MEVA.meva.utils.vis import batch_visualize_vid_preds
                video = target['video']
                dataset = 'common'
                vid_tensor = batch_visualize_vid_preds(video,
                                                       preds[-1],
                                                       target,
                                                       vis_hmr=False,
                                                       dataset=dataset)
                self.writer.add_video('valid-video',
                                      vid_tensor,
                                      global_step=self.valid_global_step,
                                      fps=10)
            # =============>

            batch_time = time.time() - start

            summary_string = f'({i + 1}/{len(self.valid_loader)}) | batch: {batch_time * 1000.0:.4}ms | ' \
                             f'Total: {bar.elapsed_td} | ETA: {bar.eta_td:}'

            self.valid_global_step += 1
            bar.suffix = summary_string
            bar.next()

        bar.finish()

        logger.info(summary_string)

    def fit(self):

        for epoch in range(self.start_epoch, self.end_epoch):
            self.epoch = epoch
            self.train()
            self.validate()
            performance = self.evaluate()

            if self.lr_scheduler is not None:
                self.lr_scheduler.step(performance)

            # log the learning rate
            for param_group in self.gen_optimizer.param_groups:
                print(f'Learning rate {param_group["lr"]}')
                self.writer.add_scalar('lr/gen_lr',
                                       param_group['lr'],
                                       global_step=self.epoch)

            # for param_group in self.dis_motion_optimizer.param_groups:
            #     print(f'Learning rate {param_group["lr"]}')
            #     self.writer.add_scalar('lr/dis_lr', param_group['lr'], global_step=self.epoch)

            logger.info(f'Epoch {epoch+1} performance: {performance:.4f}')

            self.save_model(performance, epoch)

            # if performance > 200.0:
            # exit(f'MPJPE error is {performance}, higher than 80.0. Exiting!...')

        self.writer.close()

    def save_model(self, performance, epoch):
        save_dict = {
            'epoch': epoch,
            'gen_state_dict': self.generator.state_dict(),
            'performance': performance,
            'gen_optimizer': self.gen_optimizer.state_dict(),
            # 'disc_motion_state_dict': self.motion_discriminator.state_dict(),
            # 'disc_motion_optimizer': self.dis_motion_optimizer.state_dict(),
        }

        filename = osp.join(self.logdir, 'checkpoint.pth.tar')
        torch.save(save_dict, filename)

        if self.performance_type == 'min':
            is_best = performance < self.best_performance
        else:
            is_best = performance > self.best_performance

        if is_best:
            logger.info('Best performance achieved, saving it!')
            self.best_performance = performance
            shutil.copyfile(filename,
                            osp.join(self.logdir, 'model_best.pth.tar'))

            with open(osp.join(self.logdir, 'best.txt'), 'w') as f:
                f.write(str(float(performance)))

    def resume_pretrained(self, model_path):
        if osp.isfile(model_path):
            checkpoint = torch.load(model_path)
            self.start_epoch = checkpoint['epoch']
            self.generator.load_state_dict(checkpoint['gen_state_dict'])
            self.gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
            self.best_performance = checkpoint['performance']

            # if 'disc_motion_optimizer' in checkpoint.keys():
            # self.motion_discriminator.load_state_dict(checkpoint['disc_motion_state_dict'])
            # self.dis_motion_optimizer.load_state_dict(checkpoint['disc_motion_optimizer'])

            logger.info(
                f"=> loaded checkpoint '{model_path}' "
                f"(epoch {self.start_epoch}, performance {self.best_performance})"
            )
        else:
            logger.info(f"=> no checkpoint found at '{model_path}'")

    def evaluate(self):

        for k, v in self.evaluation_accumulators.items():
            self.evaluation_accumulators[k] = np.vstack(v)

        pred_j3ds = self.evaluation_accumulators['pred_j3d']
        target_j3ds = self.evaluation_accumulators['target_j3d']

        pred_j3ds = torch.from_numpy(pred_j3ds).float()
        target_j3ds = torch.from_numpy(target_j3ds).float()

        print(f'Evaluating on {pred_j3ds.shape[0]} poses...')
        pred_pelvis = (pred_j3ds[:, [2], :] + pred_j3ds[:, [3], :]) / 2.0
        target_pelvis = (target_j3ds[:, [2], :] + target_j3ds[:, [3], :]) / 2.0

        pred_j3ds -= pred_pelvis
        target_j3ds -= target_pelvis
        # Absolute error (MPJPE)
        errors = torch.sqrt(
            ((pred_j3ds -
              target_j3ds)**2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
        S1_hat = batch_compute_similarity_transform_torch(
            pred_j3ds, target_j3ds)
        errors_pa = torch.sqrt(
            ((S1_hat -
              target_j3ds)**2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
        pred_verts = self.evaluation_accumulators['pred_verts']
        target_theta = self.evaluation_accumulators['target_theta']

        m2mm = 1000

        pve = np.mean(
            compute_error_verts(target_theta=target_theta,
                                pred_verts=pred_verts)) * m2mm
        accel = np.mean(compute_accel(pred_j3ds)) * m2mm
        accel_err = np.mean(
            compute_error_accel(joints_pred=pred_j3ds,
                                joints_gt=target_j3ds)) * m2mm
        mpjpe = np.mean(errors) * m2mm
        pa_mpjpe = np.mean(errors_pa) * m2mm

        eval_dict = {
            'mpjpe': mpjpe,
            'pa-mpjpe': pa_mpjpe,
            'accel': accel,
            'pve': pve,
            'accel_err': accel_err
        }

        log_str = f'Epoch {self.epoch}, '
        log_str += ' '.join(
            [f'{k.upper()}: {v:.4f},' for k, v in eval_dict.items()])
        logger.info(log_str)

        for k, v in eval_dict.items():
            self.writer.add_scalar(f'error/{k}', v, global_step=self.epoch)

        return pa_mpjpe
Example #16
class Logger(object):
    def __init__(self, log_dir, comment=''):
        self.writer = SummaryWriter(log_dir=log_dir, comment=comment)
        self.imgs_dict = {}

    def scalar_summary(self, tag, value, step):
        self.writer.add_scalar(tag, value, global_step=step)
        self.writer.flush()

    def combined_scalars_summary(self, main_tag, tag_scalar_dict, step):
        self.writer.add_scalars(main_tag, tag_scalar_dict, step)
        self.writer.flush()

    def log(self, tag, text_string, step=0):
        self.writer.add_text(tag, text_string, step)
        self.writer.flush()

    def log_model(self, model, inputs):
        self.writer.add_graph(model, inputs)
        self.writer.flush()

    def get_dir(self):
        return self.writer.get_logdir()

    def log_model_state(self, model, name='tmp'):
        path = os.path.join(self.writer.get_logdir(),
                            type(model).__name__ + '_%s.pt' % name)
        torch.save(model.state_dict(), path)

    def log_video(self,
                  tag,
                  global_step=None,
                  img_tns=None,
                  finished_video=False,
                  video_tns=None,
                  debug=False):
        '''
            Logs video to tensorboard.
            If video_tns is given, it is logged directly and the other
            arguments are ignored. Otherwise each img_tns frame is
            accumulated under `tag`, and when finished_video=True the
            accumulated frames are logged as a single video and cleared.
        '''
        if debug:
            import pdb
            pdb.set_trace()
        if img_tns is None and video_tns is None:
            if not finished_video or tag not in self.imgs_dict.keys():
                return None
            lst_img_tns = self.imgs_dict[tag]
            self.writer.add_video(tag,
                                  torch.tensor(lst_img_tns),
                                  global_step=global_step,
                                  fps=4)
            self.writer.flush()
            self.imgs_dict[tag] = []
            return None
        elif video_tns is not None:
            self.writer.add_video(tag,
                                  video_tns,
                                  global_step=global_step,
                                  fps=4)
            self.writer.flush()
            return None

        if tag in self.imgs_dict.keys():
            lst_img_tns = self.imgs_dict[tag]
        else:
            lst_img_tns = []
            self.imgs_dict[tag] = lst_img_tns

        lst_img_tns.append(img_tns)

        if finished_video:
            self.writer.add_video(tag,
                                  torch.tensor(lst_img_tns),
                                  global_step=global_step,
                                  fps=4)
            self.writer.flush()
            self.imgs_dict[tag].clear()

    def close(self):
        self.writer.close()
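
A minimal sketch of the direct path of log_video, passing a ready (N, T, C, H, W) tensor via video_tns; the frame-accumulation path expects array-like frames that torch.tensor can stack into the same 5-D layout:

import torch

logger = Logger('runs/exp')
clip = torch.rand(1, 8, 3, 32, 32)               # (N, T, C, H, W), floats in [0, 1]
logger.log_video('rollout', global_step=0, video_tns=clip)
logger.close()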
Example #17
class Logger(object):
    def __init__(self,
                 log_dir,
                 save_tb=False,
                 log_frequency=10000,
                 agent="sac"):
        self._log_dir = log_dir
        self._log_frequency = log_frequency
        if save_tb:
            tb_dir = os.path.join(log_dir, "tb")
            if os.path.exists(tb_dir):
                try:
                    shutil.rmtree(tb_dir)
                except OSError:
                    print("logger.py warning: Unable to remove tb directory")
            self._sw = SummaryWriter(tb_dir)
        else:
            self._sw = None
        # each agent has specific output format for training
        assert agent in AGENT_TRAIN_FORMAT
        train_format = COMMON_TRAIN_FORMAT + AGENT_TRAIN_FORMAT[agent]
        self._train_mg = MetersGroup(os.path.join(log_dir, "train"),
                                     formating=train_format)
        self._eval_mg = MetersGroup(os.path.join(log_dir, "eval"),
                                    formating=COMMON_EVAL_FORMAT)

    def _should_log(self, step, log_frequency):
        log_frequency = log_frequency or self._log_frequency
        return step % log_frequency == 0

    def _try_sw_log(self, key, value, step):
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def _try_sw_log_video(self, key, frames, step):
        if self._sw is not None:
            frames = torch.from_numpy(np.array(frames))
            frames = frames.unsqueeze(0)
            self._sw.add_video(key, frames, step, fps=30)

    def _try_sw_log_histogram(self, key, histogram, step):
        if self._sw is not None:
            self._sw.add_histogram(key, histogram, step)

    def log(self, key, value, step, n=1, log_frequency=1):
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith("train") or key.startswith("eval")
        if type(value) == torch.Tensor:
            value = value.item()
        self._try_sw_log(key, value / n, step)
        mg = self._train_mg if key.startswith("train") else self._eval_mg
        mg.log(key, value, n)

    def log_param(self, key, param, step, log_frequency=None):
        if not self._should_log(step, log_frequency):
            return
        self.log_histogram(key + "_w", param.weight.data, step)
        if hasattr(param.weight, "grad") and param.weight.grad is not None:
            self.log_histogram(key + "_w_g", param.weight.grad.data, step)
        if hasattr(param, "bias") and hasattr(param.bias, "data"):
            self.log_histogram(key + "_b", param.bias.data, step)
            if hasattr(param.bias, "grad") and param.bias.grad is not None:
                self.log_histogram(key + "_b_g", param.bias.grad.data, step)

    def log_video(self, key, frames, step, log_frequency=None):
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith("train") or key.startswith("eval")
        self._try_sw_log_video(key, frames, step)

    def log_histogram(self, key, histogram, step, log_frequency=None):
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith("train") or key.startswith("eval")
        self._try_sw_log_histogram(key, histogram, step)

    def dump(self, step, save=True, ty=None):
        if ty is None:
            self._train_mg.dump(step, "train", save)
            self._eval_mg.dump(step, "eval", save)
        elif ty == "eval":
            self._eval_mg.dump(step, "eval", save)
        elif ty == "train":
            self._train_mg.dump(step, "train", save)
        else:
            raise ValueError(f"invalid log type: {ty}")
Example #18
class A2CTrial(PyTorchTrial):
    def __init__(self, trial_context: PyTorchTrialContext) -> None:
        self.context = trial_context
        self.download_directory = f"/tmp/data-rank{self.context.distributed.get_rank()}"
        # self.logger = TorchWriter()
        self.n_stack = self.context.get_hparam("n_stack")
        self.env_name = self.context.get_hparam("env_name")
        self.num_envs = self.context.get_hparam("num_envs")
        self.rollout_size = self.context.get_hparam("rollout_size")
        self.curiousity = self.context.get_hparam("curiousity")
        self.lr = self.context.get_hparam("lr")
        self.icm_beta = self.context.get_hparam("icm_beta")
        self.value_coeff = self.context.get_hparam("value_coeff")
        self.entropy_coeff = self.context.get_hparam("entropy_coeff")
        self.max_grad_norm = self.context.get_hparam("max_grad_norm")

        env = make_atari_env(self.env_name, num_env=self.num_envs, seed=42)
        self.env = VecFrameStack(env, n_stack=self.n_stack)
        eval_env = make_atari_env(self.env_name, num_env=1, seed=42)
        self.eval_env = VecFrameStack(eval_env, n_stack=self.n_stack)

        # constants
        self.in_size = self.context.get_hparam("in_size")  # in_size
        self.num_actions = env.action_space.n

        def init_(m):
            return init(m, nn.init.orthogonal_,
                        lambda x: nn.init.constant_(x, 0))

        self.feat_enc_net = self.context.Model(
            FeatureEncoderNet(self.n_stack, self.in_size))
        self.actor = self.context.Model(
            init_(nn.Linear(self.feat_enc_net.hidden_size, self.num_actions)))
        self.critic = self.context.Model(
            init_(nn.Linear(self.feat_enc_net.hidden_size, 1)))
        self.set_recurrent_buffers(self.num_envs)

        params = list(self.feat_enc_net.parameters()) + list(
            self.actor.parameters()) + list(self.critic.parameters())
        self.opt = self.context.Optimizer(torch.optim.Adam(params, self.lr))

        self.is_cuda = torch.cuda.is_available()
        self.storage = RolloutStorage(self.rollout_size,
                                      self.num_envs,
                                      self.env.observation_space.shape[0:-1],
                                      self.n_stack,
                                      is_cuda=self.is_cuda,
                                      value_coeff=self.value_coeff,
                                      entropy_coeff=self.entropy_coeff)

        obs = self.env.reset()
        self.storage.states[0].copy_(self.storage.obs2tensor(obs))

        self.writer = SummaryWriter(log_dir="/tmp/tensorboard")
        self.global_eval_count = 0

    def set_recurrent_buffers(self, buf_size):
        self.feat_enc_net.reset_lstm(buf_size=buf_size)

    def reset_recurrent_buffers(self, reset_indices):
        self.feat_enc_net.reset_lstm(reset_indices=reset_indices)

    def build_training_data_loader(self) -> DataLoader:
        ds = torchvision.datasets.MNIST(
            self.download_directory,
            train=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # These are the precomputed mean and standard deviation of the
                # MNIST data; this normalizes the data to have zero mean and unit
                # standard deviation.
                transforms.Normalize((0.1307, ), (0.3081, )),
            ]),
            download=True)
        return DataLoader(ds, batch_size=1)

    def build_validation_data_loader(self) -> DataLoader:
        ds = torchvision.datasets.MNIST(
            self.download_directory,
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # These are the precomputed mean and standard deviation of the
                # MNIST data; this normalizes the data to have zero mean and unit
                # standard deviation.
                transforms.Normalize((0.1307, ), (0.3081, )),
            ]),
            download=True)
        return DataLoader(ds, batch_size=1)

    def train_batch(self, batch: TorchData, model: nn.Module, epoch_idx: int,
                    batch_idx: int) -> Dict[str, torch.Tensor]:
        final_value, entropy = self.episode_rollout()
        self.opt.zero_grad()
        total_loss, value_loss, policy_loss, entropy_loss = self.storage.a2c_loss(
            final_value, entropy)
        self.context.backward(total_loss)

        def clip_grads(parameters):
            torch.nn.utils.clip_grad_norm_(parameters, self.max_grad_norm)

        self.context.step_optimizer(self.opt, clip_grads)
        self.storage.after_update()

        return {
            'loss': total_loss,
            'value_loss': value_loss,
            'policy_loss': policy_loss,
            'entropy_loss': entropy_loss
        }

    def get_action(self, state, deterministic=False):
        feature = self.feat_enc_net(state)

        # calculate policy and value function
        policy = self.actor(feature)
        value = torch.squeeze(self.critic(feature))

        action_prob = F.softmax(policy, dim=-1)
        cat = Categorical(action_prob)
        if not deterministic:
            action = cat.sample()
            return (action, cat.log_prob(action), cat.entropy().mean(), value,
                    feature)
        else:
            action = np.argmax(action_prob.detach().cpu().numpy(), axis=1)
            return (action, [], [], value, feature)

    def episode_rollout(self):
        episode_entropy = 0
        for step in range(self.rollout_size):
            """Interact with the environments """
            # call A2C
            a_t, log_p_a_t, entropy, value, a2c_features = self.get_action(
                self.storage.get_state(step))
            # accumulate episode entropy
            episode_entropy += entropy

            # interact
            obs, rewards, dones, infos = self.env.step(a_t.cpu().numpy())

            # save episode reward
            self.storage.log_episode_rewards(infos)

            self.storage.insert(step, rewards, obs, a_t, log_p_a_t, value,
                                dones)
            self.reset_recurrent_buffers(reset_indices=dones)

        # Note:
        # get the estimate of the final reward
        # that's why we have the CRITIC --> estimate final reward
        # detach, as the final value is only used as a bootstrap target
        with torch.no_grad():
            state = self.storage.get_state(step + 1)
            final_features = self.feat_enc_net(state)
            final_value = torch.squeeze(self.critic(final_features))

        return final_value, episode_entropy

    def evaluate_full_dataset(self, data_loader, model) -> Dict[str, Any]:
        self.global_eval_count += 1
        episode_rewards, episode_lengths = [], []
        n_eval_episodes = 10
        self.set_recurrent_buffers(1)
        frames = []
        with torch.no_grad():
            for episode in range(n_eval_episodes):
                obs = self.eval_env.reset()
                done, state = False, None
                episode_reward = 0.0
                episode_length = 0
                while not done:
                    state = self.storage.obs2tensor(obs)
                    if episode == 0:
                        frame = torch.unsqueeze(torch.squeeze(state)[0],
                                                0).detach()
                        frames.append(frame)
                    action, _, _, _, _ = self.get_action(state,
                                                         deterministic=True)
                    obs, reward, done, _info = self.eval_env.step(action)
                    reward = reward[0]
                    done = done[0]
                    episode_reward += reward
                    episode_length += 1
                if episode == 0:
                    video = torch.unsqueeze(torch.stack(frames), 0)
                    self.writer.add_video('policy',
                                          video,
                                          global_step=self.global_eval_count,
                                          fps=20)
                episode_rewards.append(episode_reward)
                episode_lengths.append(episode_length)

        mean_reward = np.mean(episode_rewards)
        std_reward = np.std(episode_rewards)
        self.set_recurrent_buffers(self.num_envs)
        return {'mean_reward': mean_reward}
Example #19
class TensorboardWriter(object):
    """
    Helper class to log information to Tensorboard.
    """
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): configs. Details can be found in
                slowfast/config/defaults.py
        """
        # class_names: list of class names.
        # cm_subset_classes: a list of class ids -- a user-specified subset.
        # parent_map: dictionary where key is the parent class name and
        #   value is a list of ids of its children classes.
        # hist_subset_classes: a list of class ids -- user-specified to plot histograms.
        (
            self.class_names,
            self.cm_subset_classes,
            self.parent_map,
            self.hist_subset_classes,
        ) = (None, None, None, None)
        self.cfg = cfg
        self.cm_figsize = cfg.TENSORBOARD.CONFUSION_MATRIX.FIGSIZE
        self.hist_figsize = cfg.TENSORBOARD.HISTOGRAM.FIGSIZE

        if cfg.TENSORBOARD.LOG_DIR == "":
            log_dir = os.path.join(cfg.OUTPUT_DIR,
                                   "runs-{}".format(cfg.TRAIN.DATASET))
        else:
            log_dir = os.path.join(cfg.OUTPUT_DIR, cfg.TENSORBOARD.LOG_DIR)

        self.writer = SummaryWriter(log_dir=log_dir)
        logger.info("To see logged results in Tensorboard, please launch using "
                    "the command `tensorboard --port=<port-number> --logdir {}`"
                    .format(log_dir))

        if cfg.TENSORBOARD.CLASS_NAMES_PATH != "":
            if cfg.DETECTION.ENABLE:
                logger.info("Plotting confusion matrix is currently \
                    not supported for detection.")
            (
                self.class_names,
                self.parent_map,
                self.cm_subset_classes,
            ) = get_class_names(
                cfg.TENSORBOARD.CLASS_NAMES_PATH,
                cfg.TENSORBOARD.CATEGORIES_PATH,
                cfg.TENSORBOARD.CONFUSION_MATRIX.SUBSET_PATH,
            )

            if cfg.TENSORBOARD.HISTOGRAM.ENABLE:
                if cfg.DETECTION.ENABLE:
                    logger.info("Plotting histogram is not currently \
                    supported for detection tasks.")
                if cfg.TENSORBOARD.HISTOGRAM.SUBSET_PATH != "":
                    _, _, self.hist_subset_classes = get_class_names(
                        cfg.TENSORBOARD.CLASS_NAMES_PATH,
                        None,
                        cfg.TENSORBOARD.HISTOGRAM.SUBSET_PATH,
                    )

    def add_scalars(self, data_dict, global_step=None):
        """
        Add multiple scalars to Tensorboard logs.
        Args:
            data_dict (dict): key is a string specifying the tag of value.
            global_step (Optional[int]): Global step value to record.
        """
        if self.writer is not None:
            for key, item in data_dict.items():
                self.writer.add_scalar(key, item, global_step)

    def plot_eval(self, preds, labels, global_step=None):
        """
        Plot confusion matrices and histograms for eval/test set.
        Args:
            preds (tensor or list of tensors): list of predictions.
            labels (tensor or list of tensors): list of labels.
            global_step (Optional[int]): current step in eval/test.
        """
        if not self.cfg.DETECTION.ENABLE:
            cmtx = None
            if self.cfg.TENSORBOARD.CONFUSION_MATRIX.ENABLE:
                cmtx = vis_utils.get_confusion_matrix(
                    preds, labels, self.cfg.MODEL.NUM_CLASSES)
                # Add full confusion matrix.
                add_confusion_matrix(
                    self.writer,
                    cmtx,
                    self.cfg.MODEL.NUM_CLASSES,
                    global_step=global_step,
                    class_names=self.class_names,
                    figsize=self.cm_figsize,
                )
                # If a list of subset is provided, plot confusion matrix subset.
                if self.cm_subset_classes is not None:
                    add_confusion_matrix(
                        self.writer,
                        cmtx,
                        self.cfg.MODEL.NUM_CLASSES,
                        global_step=global_step,
                        subset_ids=self.cm_subset_classes,
                        class_names=self.class_names,
                        tag="Confusion Matrix Subset",
                        figsize=self.cm_figsize,
                    )
                # If a parent-child classes mapping is provided, plot confusion
                # matrices grouped by parent classes.
                if self.parent_map is not None:
                    # Get list of tags (parent categories names) and their children.
                    for parent_class, children_ls in self.parent_map.items():
                        tag = (
                            "Confusion Matrices Grouped by Parent Classes/" +
                            parent_class)
                        add_confusion_matrix(
                            self.writer,
                            cmtx,
                            self.cfg.MODEL.NUM_CLASSES,
                            global_step=global_step,
                            subset_ids=children_ls,
                            class_names=self.class_names,
                            tag=tag,
                            figsize=self.cm_figsize,
                        )
            if self.cfg.TENSORBOARD.HISTOGRAM.ENABLE:
                if cmtx is None:
                    cmtx = vis_utils.get_confusion_matrix(
                        preds, labels, self.cfg.MODEL.NUM_CLASSES)
                plot_hist(
                    self.writer,
                    cmtx,
                    self.cfg.MODEL.NUM_CLASSES,
                    self.cfg.TENSORBOARD.HISTOGRAM.TOPK,
                    global_step=global_step,
                    subset_ids=self.hist_subset_classes,
                    class_names=self.class_names,
                    figsize=self.hist_figsize,
                )

    def add_video(self,
                  vid_tensor,
                  tag="Video Input",
                  global_step=None,
                  fps=4):
        """
        Add input to tensorboard SummaryWriter as a video.
        Args:
            vid_tensor (tensor): shape of (B, T, C, H, W). Values should lie in
                [0, 255] for type uint8 or in [0, 1] for type float.
            tag (Optional[str]): name of the video.
            global_step(Optional[int]): current step.
            fps (int): frames per second.
        """
        self.writer.add_video(tag,
                              vid_tensor,
                              global_step=global_step,
                              fps=fps)

    def plot_weights_and_activations(
        self,
        weight_activation_dict,
        tag="",
        normalize=False,
        global_step=None,
        batch_idx=None,
        indexing_dict=None,
        heat_map=True,
    ):
        """
        Visualize weight/activation tensors on Tensorboard.
        Args:
            weight_activation_dict (dict[str, tensor]): a dictionary of pairs
                {layer_name: tensor}, where layer_name is a string and tensor holds
                the weights/activations of the layer to visualize.
            tag (Optional[str]): tag prefix for the plotted tensors.
            normalize (bool): If True, the tensor is normalized. (Default to False)
            global_step(Optional[int]): current step.
            batch_idx (Optional[int]): current batch index to visualize. If None,
                visualize the entire batch.
            indexing_dict (Optional[dict]): a dictionary of the {layer_name: indexing}.
                where indexing is numpy-like fancy indexing.
            heat_map (bool): whether to add a heatmap to the weights/activations.
        """
        for name, array in weight_activation_dict.items():
            if batch_idx is None:
                # Select all items in the batch if batch_idx is not provided.
                batch_idx = list(range(array.shape[0]))
            if indexing_dict is not None:
                fancy_indexing = indexing_dict[name]
                fancy_indexing = (batch_idx, ) + fancy_indexing
                array = array[fancy_indexing]
            else:
                array = array[batch_idx]
            add_ndim_array(
                self.writer,
                array,
                tag + name,
                normalize=normalize,
                global_step=global_step,
                heat_map=heat_map,
            )

    def flush(self):
        self.writer.flush()

    def close(self):
        self.writer.flush()
        self.writer.close()
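A usage sketch for the class above. It assumes SlowFast is installed so a default CfgNode can be built with `get_cfg()`; the logged values are placeholders.

import torch
from slowfast.config.defaults import get_cfg  # assumption: slowfast is installed

cfg = get_cfg()
writer = TensorboardWriter(cfg)
writer.add_scalars({'Train/loss': 0.42, 'Train/lr': 1e-3}, global_step=100)
writer.add_video(torch.rand(2, 16, 3, 112, 112), tag='Video Input', global_step=100)
writer.close()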
Example #20
class Logger:
    _count = 0

    def __init__(self, scrn=True, log_dir='', phase=''):
        super().__init__()
        self._logger = logging.getLogger('logger_{}'.format(Logger._count))
        Logger._count += 1
        self._logger.setLevel(logging.DEBUG)

        if scrn:
            self._scrn_handler = logging.StreamHandler()
            self._scrn_handler.setLevel(logging.INFO)
            self._scrn_handler.setFormatter(
                logging.Formatter(fmt=FORMAT_SHORT))
            self._logger.addHandler(self._scrn_handler)

        if log_dir and phase:
            self.log_path = os.path.join(
                log_dir,
                '{}-{:-4d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}.log'.format(
                    phase,
                    *localtime()[:6]))
            self.show_nl("log into {}\n\n".format(self.log_path))
            self._file_handler = logging.FileHandler(filename=self.log_path)
            self._file_handler.setLevel(logging.DEBUG)
            self._file_handler.setFormatter(logging.Formatter(fmt=FORMAT_LONG))
            self._logger.addHandler(self._file_handler)

            self._writer = SummaryWriter(log_dir=os.path.join(
                log_dir, '{}-{:-4d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}'.format(
                    phase,
                    *localtime()[:6])))

    def show(self, *args, **kwargs):
        return self._logger.info(*args, **kwargs)

    def show_nl(self, *args, **kwargs):
        self._logger.info("")
        return self.show(*args, **kwargs)

    def dump(self, *args, **kwargs):
        return self._logger.debug(*args, **kwargs)

    def warning(self, *args, **kwargs):
        return self._logger.warning(*args, **kwargs)

    def error(self, *args, **kwargs):
        return self._logger.error(*args, **kwargs)

    # tensorboard
    def add_scalar(self, *args, **kwargs):
        return self._writer.add_scalar(*args, **kwargs)

    def add_scalars(self, *args, **kwargs):
        return self._writer.add_scalars(*args, **kwargs)

    def add_histogram(self, *args, **kwargs):
        return self._writer.add_histogram(*args, **kwargs)

    def add_image(self, *args, **kwargs):
        return self._writer.add_image(*args, **kwargs)

    def add_images(self, *args, **kwargs):
        return self._writer.add_images(*args, **kwargs)

    def add_figure(self, *args, **kwargs):
        return self._writer.add_figure(*args, **kwargs)

    def add_video(self, *args, **kwargs):
        return self._writer.add_video(*args, **kwargs)

    def add_audio(self, *args, **kwargs):
        return self._writer.add_audio(*args, **kwargs)

    def add_text(self, *args, **kwargs):
        return self._writer.add_text(*args, **kwargs)

    def add_graph(self, *args, **kwargs):
        return self._writer.add_graph(*args, **kwargs)

    def add_pr_curve(self, *args, **kwargs):
        return self._writer.add_pr_curve(*args, **kwargs)

    def add_custom_scalars(self, *args, **kwargs):
        return self._writer.add_custom_scalars(*args, **kwargs)

    def add_mesh(self, *args, **kwargs):
        return self._writer.add_mesh(*args, **kwargs)

    # def add_hparams(self, *args, **kwargs):
    #     return self._writer.add_hparams(*args, **kwargs)

    def flush(self):
        return self._writer.flush()

    def close(self):
        return self._writer.close()

    def _grad_hook(self, grad, name=None, grads=None):
        grads.update({name: grad})

    def watch_grad(self, model, layers):
        """
        Add hooks to the specific layers. Gradients of these layers will save to self.grads
        :param model:
        :param layers: Except a list eg. layers=[0, -1] means to watch the gradients of
                        the fist layer and the last layer of the model
        :return:
        """
        assert layers
        if not hasattr(self, 'grads'):
            self.grads = {}
            self.grad_hooks = {}
        named_parameters = list(model.named_parameters())
        for layer in layers:
            name = named_parameters[layer][0]
            handle = named_parameters[layer][1].register_hook(
                functools.partial(self._grad_hook, name=name,
                                  grads=self.grads))
            # Keyed by the parameter name, so watch_grad_close removes every hook.
            self.grad_hooks.update({name: handle})

    def watch_grad_close(self):
        for _, handle in self.grad_hooks.items():
            handle.remove()  # remove the hook

    def add_grads(self, global_step=None, *args, **kwargs):
        """
        Add gradients to tensorboard. You must call the method self.watch_grad before using this method!
        """
        assert hasattr(self, 'grads'), \
            "self.grads is nonexistent! You must call self.watch_grad first!"
        assert self.grads, "self.grads is empty!"
        for (name, grad) in self.grads.items():
            self.add_histogram(tag=name,
                               values=grad,
                               global_step=global_step,
                               *args,
                               **kwargs)

    @staticmethod
    def make_desc(counter, total, *triples):
        desc = "[{}/{}]".format(counter, total)
        # The three elements of each triple are
        # (name to display, AverageMeter object, formatting string)
        for name, obj, fmt in triples:
            desc += (" {} {obj.val:" + fmt + "} ({obj.avg:" + fmt +
                     "})").format(name, obj=obj)
        return desc
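A brief sketch of the gradient-watching flow: hook parameters, run a backward pass so the hooks populate `self.grads`, then write histograms. The model and data are stand-ins, and the module-level FORMAT_* constants from the original codebase are assumed to be in place.

import torch
import torch.nn as nn

logger = Logger(scrn=True, log_dir='logs', phase='train')
model = nn.Linear(4, 2)
logger.watch_grad(model, layers=[0, -1])   # hook the first and last parameters
model(torch.randn(8, 4)).sum().backward()  # hooks fill logger.grads
logger.add_grads(global_step=0)            # one histogram per watched gradient
logger.watch_grad_close()                  # remove the hooks again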
Example #21
class Logger(object):
    def __init__(self,
                 log_dir,
                 save_tb=False,
                 log_frequency=10000,
                 agent='sac'):
        self._log_dir = log_dir
        self._log_frequency = log_frequency
        if save_tb:
            tb_dir = os.path.join(log_dir, 'tb')
            if os.path.exists(tb_dir):
                try:
                    shutil.rmtree(tb_dir)
                except OSError:
                    print("logger.py warning: Unable to remove tb directory")
            self._sw = SummaryWriter(tb_dir)
        else:
            self._sw = None
        # each agent has specific output format for training
        assert agent in AGENT_TRAIN_FORMAT
        train_format = COMMON_TRAIN_FORMAT + AGENT_TRAIN_FORMAT[agent]
        self._train_mg = MetersGroup(os.path.join(log_dir, 'train'),
                                     formating=train_format)
        self._eval_mg = MetersGroup(os.path.join(log_dir, 'eval'),
                                    formating=COMMON_EVAL_FORMAT)

    def _should_log(self, step, log_frequency):
        log_frequency = log_frequency or self._log_frequency
        return step % log_frequency == 0

    def _try_sw_log(self, key, value, step):
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def _try_sw_log_image(self, key, image, step):
        if self._sw is not None:
            assert image.dim() == 3
            grid = torchvision.utils.make_grid(image.unsqueeze(1))
            self._sw.add_image(key, grid, step)

    def _try_sw_log_video(self, key, frames, step):
        if self._sw is not None:
            frames = torch.from_numpy(np.array(frames))
            frames = frames.unsqueeze(0)
            self._sw.add_video(key, frames, step, fps=30)

    def _try_sw_log_histogram(self, key, histogram, step):
        if self._sw is not None:
            self._sw.add_histogram(key, histogram, step)

    def log(self, key, value, step, n=1, log_frequency=1):
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith('train') or key.startswith('eval')
        if isinstance(value, torch.Tensor):
            value = value.item()
        self._try_sw_log(key, value / n, step)
        mg = self._train_mg if key.startswith('train') else self._eval_mg
        mg.log(key, value, n)

    def log_param(self, key, param, step, log_frequency=None):
        if not self._should_log(step, log_frequency):
            return
        self.log_histogram(key + '_w', param.weight.data, step)
        if hasattr(param.weight, 'grad') and param.weight.grad is not None:
            self.log_histogram(key + '_w_g', param.weight.grad.data, step)
        if hasattr(param, 'bias') and hasattr(param.bias, 'data'):
            self.log_histogram(key + '_b', param.bias.data, step)
            if hasattr(param.bias, 'grad') and param.bias.grad is not None:
                self.log_histogram(key + '_b_g', param.bias.grad.data, step)

    def log_image(self, key, image, step, log_frequency=None):
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_image(key, image, step)

    def log_video(self, key, frames, step, log_frequency=None):
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_video(key, frames, step)

    def log_histogram(self, key, histogram, step, log_frequency=None):
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_histogram(key, histogram, step)

    def dump(self, step, save=True, ty=None):
        if ty is None:
            self._train_mg.dump(step, 'train', save)
            self._eval_mg.dump(step, 'eval', save)
        elif ty == 'eval':
            self._eval_mg.dump(step, 'eval', save)
        elif ty == 'train':
            self._train_mg.dump(step, 'train', save)
        else:
            raise ValueError(f'invalid log type: {ty}')
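A usage sketch for this logger, assuming `MetersGroup` and the COMMON_*/AGENT_TRAIN_FORMAT tables from the surrounding codebase are importable; note that keys must start with 'train' or 'eval'.

logger = Logger('runs/exp0', save_tb=True, agent='sac')
logger.log('train/episode_reward', 12.5, step=1000)
logger.log('eval/episode_reward', 14.0, step=1000)
logger.dump(step=1000)

Example #22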
class DisentanglingTrainer(LatentTrainer):
    def __init__(
        self,
        env,
        log_dir,
        state_rep,
        skill_policy_path,
        seed,
        run_id,
        feature_dim=5,
        num_sequences=80,
        cuda=False,
    ):
        parent_kwargs = dict(
            num_steps=3000000,
            initial_latent_steps=100000,
            batch_size=256,
            latent_batch_size=128,
            num_sequences=num_sequences,
            latent_lr=0.0001,
            feature_dim=feature_dim,  # Note: Only used if state_rep == False
            latent1_dim=8,
            latent2_dim=32,
            mode_dim=2,
            hidden_units=[256, 256],
            hidden_units_decoder=[256, 256],
            hidden_units_mode_encoder=[256, 256],
            hidden_rnn_dim=64,
            rnn_layers=2,
            memory_size=1e5,
            leaky_slope=0.2,
            grad_clip=None,
            start_steps=10000,
            training_log_interval=100,
            learning_log_interval=50,
            cuda=cuda,
            seed=seed)

        # Other
        self.run_id = run_id
        self.state_rep = state_rep

        # Comment for summary writer
        summary_comment = self.run_id

        # Environment
        self.env = env
        self.observation_shape = self.env.observation_space.shape
        self.action_shape = self.env.action_space.shape
        self.action_repeat = self.env.action_repeat

        # Seed
        torch.manual_seed(parent_kwargs['seed'])
        np.random.seed(parent_kwargs['seed'])
        self.env.seed(parent_kwargs['seed'])

        # Device
        self.device = torch.device("cuda" if parent_kwargs['cuda']
                                   and torch.cuda.is_available() else "cpu")

        # Latent Network
        self.latent = ModeDisentanglingNetwork(
            self.observation_shape,
            self.action_shape,
            feature_dim=parent_kwargs['feature_dim'],
            latent1_dim=parent_kwargs['latent1_dim'],
            latent2_dim=parent_kwargs['latent2_dim'],
            mode_dim=parent_kwargs['mode_dim'],
            hidden_units=parent_kwargs['hidden_units'],
            hidden_units_decoder=parent_kwargs['hidden_units_decoder'],
            hidden_units_mode_encoder=parent_kwargs[
                'hidden_units_mode_encoder'],
            rnn_layers=parent_kwargs['rnn_layers'],
            hidden_rnn_dim=parent_kwargs['hidden_rnn_dim'],
            leaky_slope=parent_kwargs['leaky_slope'],
            state_rep=state_rep).to(self.device)

        # Load pretrained DIAYN skill policy
        data = torch.load(skill_policy_path)
        self.policy = data['evaluation/policy']
        print("Policy loaded")

        # MI-Gradient score estimators
        self.spectral_j = SpectralScoreEstimator(n_eigen_threshold=0.99)
        self.spectral_m = SpectralScoreEstimator(n_eigen_threshold=0.99)

        # Optimization
        self.latent_optim = Adam(self.latent.parameters(),
                                 lr=parent_kwargs['latent_lr'])

        # Memory
        self.memory = MyMemoryDisentangling(
            state_rep=state_rep,
            capacity=parent_kwargs['memory_size'],
            num_sequences=parent_kwargs['num_sequences'],
            observation_shape=self.observation_shape,
            action_shape=self.action_shape,
            device=self.device)

        # Log directories
        self.log_dir = log_dir
        self.model_dir = os.path.join(log_dir, 'model')
        self.summary_dir = os.path.join(log_dir, 'summary')
        self.images_dir = os.path.join(log_dir, 'images')
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        if not os.path.exists(self.summary_dir):
            os.makedirs(self.summary_dir)
        if not os.path.exists(self.images_dir):
            os.makedirs(self.images_dir)

        # Summary writer with conversion of hparams
        # (certain types are not allowed for hparam storage)
        self.writer = SummaryWriter(os.path.join(self.summary_dir,
                                                 summary_comment),
                                    filename_suffix=self.run_id)
        hparam_dict = parent_kwargs.copy()
        for k, v in hparam_dict.items():
            if isinstance(v, type(None)):
                hparam_dict[k] = 'None'
            if isinstance(v, list):
                hparam_dict[k] = torch.Tensor(v)
        hparam_dict['hidden_units'] = torch.Tensor(
            parent_kwargs['hidden_units'])
        self.writer.add_hparams(hparam_dict=hparam_dict, metric_dict={})

        # Set hyperparameters
        self.steps = 0
        self.learning_steps = 0
        self.episodes = 0
        self.initial_latent_steps = parent_kwargs['initial_latent_steps']
        self.num_sequences = num_sequences
        self.num_steps = parent_kwargs['num_steps']
        self.batch_size = parent_kwargs['batch_size']
        self.latent_batch_size = parent_kwargs['latent_batch_size']
        self.start_steps = parent_kwargs['start_steps']
        self.grad_clip = parent_kwargs['grad_clip']
        self.training_log_interval = parent_kwargs['training_log_interval']
        self.learning_log_interval = parent_kwargs['learning_log_interval']

        # Mode action sampler
        self.mode_action_sampler = ModeActionSampler(self.latent,
                                                     device=self.device)

    def get_skill_action_pixel(self):
        obs_state_space = self.env.get_state_obs()
        action, info = self.policy.get_action(obs_state_space)
        return action

    def get_skill_action_state_rep(self, observation):
        action, info = self.policy.get_action(observation)
        return action

    def set_policy_skill(self, skill):
        self.policy.stochastic_policy.skill = skill

    def train_episode(self):
        self.episodes += 1
        episode_steps = 0
        episode_reward = 0.
        done = False
        state = self.env.reset()
        self.memory.set_initial_state(state)
        skill = np.random.randint(self.policy.stochastic_policy.skill_dim)
        #skill = np.random.choice([3, 4, 5, 7, 9], 1).item()
        self.set_policy_skill(skill)

        next_state = state
        while not done and episode_steps <= self.num_sequences + 2:
            action = self.get_skill_action_state_rep(next_state) if self.state_rep \
                else self.get_skill_action_pixel()
            next_state, reward, done, _ = self.env.step(action)
            self.steps += self.action_repeat
            episode_steps += self.action_repeat
            episode_reward += reward

            self.memory.append(action, skill, next_state, done)

            if self.is_update():
                if self.learning_steps < self.initial_latent_steps:
                    print('-' * 60)
                    print('Learning the disentangled model only...')
                    for _ in tqdm(range(self.initial_latent_steps)):
                        self.learning_steps += 1
                        self.learn_latent()
                    print('Finished learning the disentangled model')
                    print('-' * 60)

        print(f'episode: {self.episodes:<4}  '
              f'episode steps: {episode_steps:<4}  '
              f'skill: {skill:<4}  ')

        self.save_models()

    def learn_latent(self):
        # Sample sequence
        images_seq, actions_seq, skill_seq, dones_seq = \
            self.memory.sample_latent(self.latent_batch_size)

        # Calc loss
        latent_loss = self.calc_latent_loss(images_seq, actions_seq, skill_seq,
                                            dones_seq)

        # Backprop
        update_params(self.latent_optim, self.latent, latent_loss,
                      self.grad_clip)

        # Write net params
        if self._is_log(self.learning_log_interval * 5):
            self.latent.write_net_params(self.writer, self.learning_steps)

    def calc_latent_loss(self, images_seq, actions_seq, skill_seq, dones_seq):
        # Get features from images
        features_seq = self.latent.encoder(images_seq)

        # Sample from posterior dynamics
        ((latent1_post_samples, latent2_post_samples, mode_post_samples),
            (latent1_post_dists, latent2_post_dists, mode_post_dist)) = \
            self.latent.sample_posterior(actions_seq=actions_seq,
                                         features_seq=features_seq)

        # Sample from prior dynamics
        ((latent1_pri_samples, latent2_pri_samples, mode_pri_samples),
            (latent1_pri_dists, latent2_pri_dists, mode_pri_dist)) = \
            self.latent.sample_prior(features_seq)

        # KL divergence losses
        latent_kld = calc_kl_divergence(latent1_post_dists, latent1_pri_dists)
        latent1_dim = latent1_post_samples.size(2)
        seq_length = latent1_post_samples.size(1)
        latent_kld /= latent1_dim
        latent_kld /= seq_length

        mode_kld = calc_kl_divergence([mode_post_dist], [mode_pri_dist])
        mode_dim = mode_post_samples.size(2)
        mode_kld /= mode_dim
        kld_losses = mode_kld + latent_kld

        # Log likelihood loss of generated actions
        actions_seq_dists = self.latent.decoder(
            latent1_sample=latent1_post_samples,
            latent2_sample=latent2_post_samples,
            mode_sample=mode_post_samples)
        log_likelihood = actions_seq_dists.log_prob(actions_seq).mean(
            dim=0).mean()
        mse = torch.nn.functional.mse_loss(actions_seq_dists.loc, actions_seq)

        # Log likelihood loss of generated actions with latent dynamic priors and mode
        # posterior
        actions_seq_dists_mode = self.latent.decoder(
            latent1_sample=latent1_pri_samples.detach(),
            latent2_sample=latent2_pri_samples.detach(),
            mode_sample=mode_post_samples)
        ll_dyn_pri_mode_post = actions_seq_dists_mode.\
            log_prob(actions_seq).mean(dim=0).mean()

        # Log likelihood loss of generated actions with latent dynamic posteriors and
        # mode prior
        action_seq_dists_dyn = self.latent.decoder(
            latent1_sample=latent1_post_samples,
            latent2_sample=latent2_post_samples,
            mode_sample=mode_pri_samples)
        ll_dyn_post_mode_pri = action_seq_dists_dyn.log_prob(actions_seq).mean(
            dim=0).mean()

        # Maximum Mean Discrepancy (MMD)
        mode_pri_sample = mode_pri_samples[:, 0, :]
        mode_post_sample = mode_post_samples[:, 0, :]
        mmd_mode = self.compute_mmd_tutorial(mode_pri_sample, mode_post_sample)
        mmd_latent = 0
        #latent1_post_samples_trans = latent1_post_samples.transpose(0, 1)
        #latent1_pri_samples_trans = latent1_pri_samples.transpose(0, 1)
        #for idx in range(latent1_post_samples_trans.size(0)):
        #    mmd_latent += self.compute_mmd_tutorial(latent1_pri_samples_trans[idx],
        #                                            latent1_post_samples_trans[idx])
        latent1_post_samples_trans = latent1_post_samples.\
            view(latent1_post_samples.size(0), -1)
        latent1_pri_samples_trans = latent1_pri_samples.\
            view(latent1_pri_samples.size(0), -1)
        mmd_latent = self.compute_mmd_tutorial(latent1_pri_samples_trans,
                                               latent1_post_samples_trans)
        mmd_mode_weighted = mmd_mode
        mmd_latent_weighted = mmd_latent
        mmd_loss = mmd_latent_weighted + mmd_mode_weighted

        # MI-Gradient
        # m - data
        batch_size = mode_post_samples.size(0)
        features_actions_seq = torch.cat(
            [features_seq, actions_seq[:, :-1, :]], dim=2)
        xs = features_actions_seq.view(batch_size, -1)
        ys = mode_post_sample
        xs_ys = torch.cat([xs, ys], dim=1)
        gradient_estimator_m_data = entropy_surrogate(self.spectral_j, xs_ys) \
                                    - entropy_surrogate(self.spectral_m, ys)

        # m_pri - gen_data
        xs = mode_pri_sample
        ys = actions_seq_dists.loc.view(batch_size, -1)
        xs_ys = torch.cat([xs, ys], dim=1)
        gradient_estimator_m_gendata = entropy_surrogate(self.spectral_j, xs_ys) \
                                       - entropy_surrogate(self.spectral_m, ys)

        # m_post - latent_post
        #xs = mode_post_sample
        #gradient_estimator_mpost_latentpost = 0
        #for idx in range(latent1_post_samples.size(1)):
        #    ys = latent1_post_samples[:, idx, :]
        #    xs_ys = torch.cat([xs, ys], dim=1)
        #    single_estimator = entropy_surrogate(self.spectral_j, xs_ys) \
        #                       - entropy_surrogate(self.spectral_m, ys)
        #    gradient_estimator_mpost_latentpost += single_estimator
        xs = latent1_post_samples.view(batch_size, -1)
        ys = mode_post_sample
        xs_ys = torch.cat([xs, ys], dim=1)
        gradient_estimator_mpost_latentpost = entropy_surrogate(self.spectral_j, xs_ys) \
                                              - entropy_surrogate(self.spectral_m, ys)
        # Note: this estimator is computed but not used in the loss below;
        # reusing the name gradient_estimator_m_gendata here would silently
        # overwrite the quantity logged further down.

        # m-post - z-post
        #xs = mode_post_sample
        #ys = torch.cat([latent1_post_samples.view(batch_size, -1),
        #                latent2_post_samples.view(batch_size, -1)], dim=1)
        #xs_ys = torch.cat([xs, ys], dim=1)
        #gradient_estimator_m_post_z_post = entropy_surrogate(self.spectral_j, xs_ys) \
        #                                   - entropy_surrogate(self.spectral_m, ys)

        # Loss
        reg_weight = 1000.
        alpha = 0.99
        kld_info_weighted = (1. - alpha) * kld_losses
        mmd_info_weighted = (alpha + reg_weight - 1.) * mmd_loss

        reg_weight_mode = 100.
        alpha_mode = 1.
        mmd_mode_info_weighted = \
            (alpha_mode + reg_weight_mode - 1.) * mmd_mode_weighted
        kld_mode_info_weighted = (1. - alpha_mode) * mode_kld

        reg_weight_latent = 100.
        alpha_latent = 0
        mmd_latent_info_weighted = \
            (alpha_latent + reg_weight_latent -1.) * mmd_latent_weighted
        kld_latent_info_weighted = (1. - alpha_latent) * latent_kld

        loss_X = -log_likelihood
        loss_Z = kld_info_weighted + mmd_info_weighted

        latent_loss = mse + kld_losses - 1 * gradient_estimator_m_data
        #latent_loss = kld_info_weighted - log_likelihood + mmd_info_weighted
        #latent_loss = -log_likelihood \
        #              + 0.01 * latent_kld + mode_kld \
        #              - 1 * gradient_estimator_m_data \
        #              - 0 * gradient_estimator_m_gendata \
        #              #+ 1 * gradient_estimator_m_post_z_post
        #latent_loss = -log_likelihood + kld_info_weighted + mmd_info_weighted
        #latent_loss = mse\
        #              + mmd_mode_info_weighted \
        #              + kld_mode_info_weighted \
        #              + mmd_latent_info_weighted \
        #              + kld_latent_info_weighted \
        #              #+ 1 * gradient_estimator_mpost_latentpost

        latent_loss *= 10

        # Logging
        if self._is_log(self.learning_log_interval):

            # Reconstruction error
            reconst_error = (actions_seq - actions_seq_dists.loc) \
                .pow(2).mean(dim=(0, 1)).sum()
            self._summary_log('stats/reconst_error', reconst_error)
            print('reconstruction error: %f' % reconst_error)

            reconst_err_mode_post = (actions_seq - actions_seq_dists_mode.loc)\
                .pow(2).mean(dim=(0,1)).sum()
            reconst_err_dyn_post = (actions_seq - action_seq_dists_dyn.loc)\
                .pow(2).mean(dim=(0, 1)).sum()
            self._summary_log('stats/reconst_error mode post',
                              reconst_err_mode_post)
            self._summary_log('stats/reconst_error dyn post',
                              reconst_err_dyn_post)

            # KL divergence
            mode_kldiv_standard = calc_kl_divergence([mode_post_dist],
                                                     [mode_pri_dist])
            seq_kldiv_standard = calc_kl_divergence(latent1_post_dists,
                                                    latent1_pri_dists)
            kldiv_standard = mode_kldiv_standard + seq_kldiv_standard

            self._summary_log('stats_kldiv_standard/kldiv_standard',
                              kldiv_standard)
            self._summary_log('stats_kldiv_standard/mode_kldiv_standard',
                              mode_kldiv_standard)
            self._summary_log('stats_kldiv_standard/seq_kldiv_standard',
                              seq_kldiv_standard)

            self._summary_log('stats_kldiv/mode_kldiv_used_for_loss', mode_kld)
            self._summary_log('stats_kldiv/latent_kldiv_used_for_loss',
                              latent_kld)
            self._summary_log('stats_kldiv/klddiv used for loss', kld_losses)

            # Log likelihood
            self._summary_log('stats/log-likelihood', log_likelihood)
            self._summary_log('stats/mse', mse)

            # MMD
            self._summary_log('stats_mmd/mmd_weighted', mmd_info_weighted)
            self._summary_log('stats_mmd/kld_weighted', kld_info_weighted)
            self._summary_log('stats_mmd/mmd_mode_weighted', mmd_mode_weighted)
            self._summary_log('stats_mmd/mmd_latent_weighted',
                              mmd_latent_weighted)
            self._summary_log('stats_mmd_separated/mmd_mode_info_weighted',
                              mmd_mode_info_weighted)
            self._summary_log('stats_mmd_separated/kld_mode_info_weighted',
                              kld_mode_info_weighted)
            self._summary_log('stats_mmd_separated/mmd_latent_info_weighted',
                              mmd_latent_info_weighted)
            self._summary_log('stats_mmd_separated/kld_latent_info_weighted',
                              kld_latent_info_weighted)
            self._summary_log(
                'stats_mmd_separated/loss_latentZ',
                mmd_latent_info_weighted + kld_latent_info_weighted)
            self._summary_log('stats_mmd_separated/loss_modeZ',
                              mmd_mode_info_weighted + kld_mode_info_weighted)

            # MI-Grad
            self._summary_log('stats_mi/mi_grad_est m_pri generated data',
                              gradient_estimator_m_gendata)
            self._summary_log('stats_mi/mi_grad_est m_post data',
                              gradient_estimator_m_data)
            #self._summary_log('stats_mi/mi_grad_est m_post z_post',
            #                  gradient_estimator_m_post_z_post)

            # Loss
            self._summary_log('loss/network', latent_loss)

            # Save Model
            self.latent.save(os.path.join(self.model_dir, 'model.pth'))

            # Reconstruction Test
            rand_batch_idx = np.random.choice(actions_seq.size(0))
            self._reconstruction_post_test(rand_batch_idx, actions_seq,
                                           actions_seq_dists, images_seq)
            self._reconstruction_mode_post_test(
                rand_batch_idx=rand_batch_idx,
                actions_seq=actions_seq,
                mode_post_samples=mode_post_samples,
                latent1_pri_samples=latent1_pri_samples,
                latent2_pri_samples=latent2_pri_samples)
            self._reconstruction_dyn_post_test(rand_batch_idx, actions_seq,
                                               latent1_post_samples,
                                               latent2_post_samples,
                                               mode_pri_samples)

            # Latent Test
            self._plot_latent_mode_map(skill_seq, mode_post_samples)
            self._gen_mode_grid_graph(mode_post_samples)

        # Mode influence test
        if self._is_log(self.learning_log_interval * 10):
            self._gen_mode_grid_videos(mode_post_samples)

        return latent_loss

    def _gen_mode_grid_videos(self, mode_post_samples):
        seq_len = 200
        with torch.no_grad():
            modes = self._create_grid(mode_post_samples)

            for (mode_idx, mode) in enumerate(modes):

                obs = self.env.reset()
                img = self.env.render()
                img_seq = torch.from_numpy(img.astype(np.float64)).transpose(
                    0, -1).unsqueeze(0)
                self.mode_action_sampler.reset(mode=mode.unsqueeze(0))
                for step in range(seq_len):
                    action = self.mode_action_sampler(
                        self.latent.encoder(
                            torch.Tensor(obs.astype(np.float64)).to(
                                self.device).unsqueeze(0)))
                    obs, _, done, _ = self.env.step(
                        action.detach().cpu().numpy()[0])
                    img = self.env.render()
                    img = torch.from_numpy(img.astype(np.float64))\
                        .transpose(0, -1).unsqueeze(0)
                    img_seq = torch.cat([img_seq, img], dim=0)

                self.writer.add_video('mode_generation_video/mode' +
                                      str(mode_idx),
                                      vid_tensor=img_seq.unsqueeze(0).float(),
                                      global_step=self.learning_steps)

    def _gen_mode_grid_graph(self, mode_post_samples):
        # TODO: make the method universal in terms of envs
        assert len(self.env.action_space.shape) == 1,\
            'Method only works in the MountainCar case'
        seq_len = mode_post_samples.size(1)
        with torch.no_grad():
            modes = self._create_grid(mode_post_samples)

            for (mode_idx, mode) in enumerate(modes):

                obs = self.env.reset()
                self.mode_action_sampler.reset(mode=mode.unsqueeze(0))
                action = self.mode_action_sampler(
                    self.latent.encoder(
                        torch.from_numpy(obs).to(
                            self.device).unsqueeze(0).float()))
                action = action.detach().cpu().numpy()[0]
                obs_save = np.expand_dims(obs, axis=0)
                actions_save = [action]
                for step in range(seq_len):
                    obs, _, done, _ = self.env.step(action)

                    action = self.mode_action_sampler(
                        self.latent.encoder(
                            torch.from_numpy(obs).to(
                                self.device).unsqueeze(0).float()))
                    action = action.detach().cpu().numpy()[0]

                    actions_save = np.concatenate(
                        (actions_save, np.expand_dims(action, axis=0)), axis=0)
                    obs_save = np.concatenate(
                        (obs_save, np.expand_dims(obs, axis=0)), axis=0)
                plt.interactive(False)
                axes = plt.gca()
                axes.set_ylim([-1.5, 1.5])
                plt.plot(actions_save, label='actions')
                for dim in range(obs_save.shape[1]):
                    plt.plot(obs_save[:, dim], label='state_dim' + str(dim))
                fig = plt.gcf()
                self.writer.add_figure('mode_grid_plot_test/mode' +
                                       str(mode_idx),
                                       figure=fig,
                                       global_step=self.learning_steps)

    def _create_grid(self, mode_post_samples):
        mode_dim = mode_post_samples.size(2)
        grid_vec = torch.linspace(-2., 2., 4)
        grid_vec_list = [grid_vec] * mode_dim
        grid = torch.meshgrid(*grid_vec_list)
        modes = torch.stack(list(grid)).view(mode_dim, -1) \
            .transpose(0, -1).to(self.device)  # N x mode_dim
        return modes

    def _plot_latent_mode_map(self, skill_seq, mode_post_samples):
        with torch.no_grad():
            images_seq, actions_seq, skill_seq, dones_seq = \
                self.memory.sample_latent(128)
            features_seq = self.latent.encoder(images_seq)
            mode_post_dist = self.latent.mode_posterior(
                features_seq=features_seq.transpose(0, 1),
                actions_seq=actions_seq.transpose(0, 1))
            mode_post_sample = mode_post_dist.rsample()

            if mode_post_sample.size(1) == 2:
                colors = [
                    'b', 'g', 'r', 'c', 'm', 'y', 'k', 'darkorange', 'gray',
                    'lightgreen'
                ]
                skills = skill_seq.mean(
                    dim=1).detach().cpu().squeeze().numpy().astype(np.int64)
                plt.interactive(False)
                axes = plt.gca()
                axes.set_ylim([-3, 3])
                axes.set_xlim([-3, 3])
                #for (idx, skill) in enumerate(skills):
                #    color = colors[skill.item()]
                #    plt.scatter(mode_post_samples[idx, 0].detach().cpu().numpy(),
                #                mode_post_samples[idx, 1].detach().cpu().numpy(),
                #                label=skill, c=color)
                for skill in range(10):
                    idx = skills == skill
                    color = colors[skill]
                    plt.scatter(mode_post_sample[idx,
                                                 0].detach().cpu().numpy(),
                                mode_post_sample[idx,
                                                 1].detach().cpu().numpy(),
                                label=skill,
                                c=color)

                axes.legend()
                axes.grid(True)
                fig = plt.gcf()
                self.writer.add_figure('Latent_test/mode mapping',
                                       fig,
                                       global_step=self.learning_steps)

    def _reconstruction_post_test(self, rand_batch_idx, actions_seq,
                                  actions_seq_dists, states):
        """
        Test reconstruction of inferred posterior
        Args:
            rand_batch_idx      : which part of batch to use
            actions_seq         : actions sequence sampled from data
            actions_seq_dists   : distribution of inferred posterior actions (reconstruction)
        """
        # Reconstruction test
        rand_batch = rand_batch_idx
        action_dim = actions_seq.size(2)
        gt_actions = actions_seq[rand_batch].detach().cpu()
        post_actions = actions_seq_dists.loc[rand_batch].detach().cpu()
        states = states[rand_batch].detach().cpu()
        for dim in range(action_dim):
            fig = self._reconstruction_test_plot(dim, gt_actions, post_actions,
                                                 states)
            self.writer.add_figure('Reconst_post_test/reconst test dim' +
                                   str(dim),
                                   fig,
                                   global_step=self.learning_steps)
            plt.clf()

    def _reconstruction_mode_post_test(self, rand_batch_idx, actions_seq,
                                       latent1_pri_samples,
                                       latent2_pri_samples, mode_post_samples):
        """
        Test if mode inference works
        Args:
            rand_batch_idx      : which part of batch to use
            actions_seq         : actions sequence sampled from data
            mode_post_samples   : Sample from the inferred mode posterior distribution
            latent1_pri_samples : Samples from the dynamics prior
            latent2_pri_samples : Samples from the dynamics prior
        """
        # Use random sample from batch
        rand_batch = rand_batch_idx

        # Decode
        actions_seq_dists = self.latent.decoder(
            latent1_sample=latent1_pri_samples[rand_batch],
            latent2_sample=latent2_pri_samples[rand_batch],
            mode_sample=mode_post_samples[rand_batch])

        # Reconstruction test
        action_dim = actions_seq.size(2)
        gt_actions = actions_seq[rand_batch].detach().cpu()
        post_actions = actions_seq_dists.loc.detach().cpu()

        # Plot
        for dim in range(action_dim):
            fig = self._reconstruction_test_plot(dim, gt_actions, post_actions)
            self.writer.add_figure('Reconst_mode_post_test/reconst test dim' +
                                   str(dim),
                                   fig,
                                   global_step=self.learning_steps)

    def _reconstruction_dyn_post_test(self, rand_batch_idx, actions_seq,
                                      latent1_post_samples,
                                      latent2_post_samples, mode_pri_samples):
        """
        Test the influence of the latent dynamics inference
        """
        # Use random sample from batch
        rand_batch = rand_batch_idx

        # Decode
        actions_seq_dists = self.latent.decoder(
            latent1_sample=latent1_post_samples[rand_batch],
            latent2_sample=latent2_post_samples[rand_batch],
            mode_sample=mode_pri_samples[rand_batch])

        # Reconstruction test
        action_dim = actions_seq.size(2)
        gt_actions = actions_seq[rand_batch].detach().cpu()
        post_actions = actions_seq_dists.loc.detach().cpu()

        # Plot
        for dim in range(action_dim):
            fig = self._reconstruction_test_plot(dim, gt_actions, post_actions)
            self.writer.add_figure('Reconst_dyn_post_test/reconst test dim' +
                                   str(dim),
                                   fig,
                                   global_step=self.learning_steps)

    def _reconstruction_test_plot(self,
                                  dim,
                                  gt_actions,
                                  post_actions,
                                  states=None):
        plt.interactive(False)
        axes = plt.gca()
        axes.set_ylim([-1.5, 1.5])
        plt.plot(gt_actions[:, dim].numpy())
        plt.plot(post_actions[:, dim].numpy())
        if states is not None:
            for state_dim in range(states.size(1)):
                plt.plot(states[:, state_dim].numpy())

        fig = plt.gcf()
        return fig

    def compute_kernel_tutorial(self, x, y):
        x_size = x.size(0)
        y_size = y.size(0)
        dim = x.size(1)
        x = x.unsqueeze(1)  # (x_size, 1, dim)
        y = y.unsqueeze(0)  # (1, y_size, dim)
        tiled_x = x.expand(x_size, y_size, dim)
        tiled_y = y.expand(x_size, y_size, dim)
        kernel_input = (tiled_x - tiled_y).pow(2).mean(2) / float(dim)
        return torch.exp(-kernel_input)  # (x_size, y_size)

    def compute_mmd_tutorial(self, x, y):
        assert x.shape == y.shape
        x_kernel = self.compute_kernel_tutorial(x, x)
        y_kernel = self.compute_kernel_tutorial(y, y)
        xy_kernel = self.compute_kernel_tutorial(x, y)
        mmd = x_kernel.mean() + y_kernel.mean() - 2 * xy_kernel.mean()
        return mmd

    def save_models(self):
        path_name = os.path.join(self.model_dir, self.run_id)
        self.latent.save(path_name + 'model_state_dict.pth')
        torch.save(self.latent, path_name + 'whole_model.pth')
        #np.save(self.memory, os.path.join(self.model_dir, 'memory.pth'))

    def _is_log(self, log_interval):
        return self.learning_steps % log_interval == 0

    def _summary_log(self, data_name, data):
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().item()
        self.writer.add_scalar(data_name, data, self.learning_steps)
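compute_mmd_tutorial above implements the standard biased MMD estimator, mean k(x, x') + mean k(y, y') - 2 * mean k(x, y), with the RBF kernel from compute_kernel_tutorial. A standalone check on random data (shapes assumed):

import torch

def rbf_kernel(x, y):
    # x: (n, d), y: (m, d) -> (n, m) kernel matrix, mirroring compute_kernel_tutorial
    dim = x.size(1)
    diff = x.unsqueeze(1) - y.unsqueeze(0)
    return torch.exp(-diff.pow(2).mean(2) / float(dim))

x = torch.randn(128, 2)
y = torch.randn(128, 2) + 1.0  # shifted distribution
mmd = rbf_kernel(x, x).mean() + rbf_kernel(y, y).mean() - 2 * rbf_kernel(x, y).mean()
print(mmd.item())  # clearly positive, since x and y differ in mean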
Example #23
class SummaryWriter:
    def __init__(self, logdir, flush_secs=120):

        self.writer = TensorboardSummaryWriter(
            log_dir=logdir,
            purge_step=None,
            max_queue=10,
            flush_secs=flush_secs,
            filename_suffix='')

        self.global_step = None
        self.active = True

        # ------------------------------------------------------------------------
        # register add_* and set_* functions in summary module on instantiation
        # ------------------------------------------------------------------------
        this_module = sys.modules[__name__]
        list_of_names = dir(SummaryWriter)
        for name in list_of_names:

            # add functions (without the 'add' prefix)
            if name.startswith('add_'):
                setattr(this_module, name[4:], getattr(self, name))

            #  set functions
            if name.startswith('set_'):
                setattr(this_module, name, getattr(self, name))

    def set_global_step(self, value):
        self.global_step = value

    def set_active(self, value):
        self.active = value

    def add_audio(self, tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_audio(
                tag, snd_tensor, global_step=global_step, sample_rate=sample_rate, walltime=walltime)

    def add_custom_scalars(self, layout):
        if self.active:
            self.writer.add_custom_scalars(layout)

    def add_custom_scalars_marginchart(self, tags, category='default', title='untitled'):
        if self.active:
            self.writer.add_custom_scalars_marginchart(tags, category=category, title=title)

    def add_custom_scalars_multilinechart(self, tags, category='default', title='untitled'):
        if self.active:
            self.writer.add_custom_scalars_multilinechart(tags, category=category, title=title)

    def add_embedding(self, mat, metadata=None, label_img=None, global_step=None,
                      tag='default', metadata_header=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_embedding(
                mat, metadata=metadata, label_img=label_img, global_step=global_step,
                tag=tag, metadata_header=metadata_header)

    def add_figure(self, tag, figure, global_step=None, close=True, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_figure(
                tag, figure, global_step=global_step, close=close, walltime=walltime)

    def add_graph(self, model, input_to_model=None, verbose=False):
        if self.active:
            self.writer.add_graph(model, input_to_model=input_to_model, verbose=verbose)

    def add_histogram(self, tag, values, global_step=None, bins='tensorflow', walltime=None, max_bins=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_histogram(
                tag, values, global_step=global_step, bins=bins,
                walltime=walltime, max_bins=max_bins)

    def add_histogram_raw(self, tag, min, max, num, sum, sum_squares,
                          bucket_limits, bucket_counts, global_step=None,
                          walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_histogram_raw(
                tag, min=min, max=max, num=num, sum=sum, sum_squares=sum_squares,
                bucket_limits=bucket_limits, bucket_counts=bucket_counts,
                global_step=global_step, walltime=walltime)

    def add_image(self, tag, img_tensor, global_step=None, walltime=None, dataformats='CHW'):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_image(
                tag, img_tensor, global_step=global_step, walltime=walltime, dataformats=dataformats)

    def add_image_with_boxes(self, tag, img_tensor, box_tensor, global_step=None,
                             walltime=None, rescale=1, dataformats='CHW'):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_image_with_boxes(
                tag, img_tensor, box_tensor,
                global_step=global_step, walltime=walltime,
                rescale=rescale, dataformats=dataformats)

    def add_images(self, tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW'):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_images(
                tag, img_tensor, global_step=global_step, walltime=walltime, dataformats=dataformats)

    def add_mesh(self, tag, vertices, colors=None, faces=None, config_dict=None, global_step=None, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_mesh(
                tag, vertices, colors=colors, faces=faces, config_dict=config_dict,
                global_step=global_step, walltime=walltime)

    def add_onnx_graph(self, graph):
        if self.active:
            self.writer.add_onnx_graph(graph)

    def add_pr_curve(self, tag, labels, predictions, global_step=None,
                     num_thresholds=127, weights=None, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_pr_curve(
                tag, labels, predictions, global_step=global_step,
                num_thresholds=num_thresholds, weights=weights, walltime=walltime)

    def add_pr_curve_raw(self, tag, true_positive_counts,
                         false_positive_counts,
                         true_negative_counts,
                         false_negative_counts,
                         precision,
                         recall,
                         global_step=None,
                         num_thresholds=127,
                         weights=None,
                         walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_pr_curve_raw(
                tag, true_positive_counts,
                false_positive_counts,
                true_negative_counts,
                false_negative_counts,
                precision,
                recall,
                global_step=global_step,
                num_thresholds=num_thresholds,
                weights=weights,
                walltime=walltime)

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_scalar(
                tag, scalar_value, global_step=global_step, walltime=walltime)

    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_scalars(
                main_tag, tag_scalar_dict, global_step=global_step, walltime=walltime)

    def add_text(self, tag, text_string, global_step=None, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_text(
                tag, text_string, global_step=global_step, walltime=walltime)

    def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_video(
                tag, vid_tensor, global_step=global_step, fps=fps, walltime=walltime)

    def close(self):
        self.writer.close()

    def __enter__(self):
        return self.writer.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        return self.writer.__exit__(exc_type, exc_val, exc_tb)
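Because the constructor mirrors every `add_*` method (minus the prefix) and every `set_*` method onto the defining module, callers can log through module-level functions. A sketch, assuming the class above lives in a module named e.g. `summary`:

import summary  # hypothetical module containing the class above

writer = summary.SummaryWriter('runs/exp')
summary.set_global_step(100)   # registered by the constructor
summary.scalar('loss', 0.3)    # add_scalar, logged at the stored global step
summary.set_active(False)      # subsequent writes become no-ops
summary.scalar('loss', 0.2)    # dropped silently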
Example #24
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default='data/bair')
    parser.add_argument('--model_path', type=str, default='model/bair')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--horizon', type=int, default=10)
    parser.add_argument('--cpu_workers', type=int, default=4)
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--model_name', type=str, default='cdna')
    parser.add_argument('--load_point', type=int, default=10)
    parser.add_argument('--no-gif', dest='save_gif', action='store_false')
    parser.set_defaults(save_gif=True)

    args = parser.parse_args()

    device = 'cuda:%d' % args.gpu_id if torch.cuda.device_count() > 0 else 'cpu'

    # dataset setup
    test_set = VideoDataset(args.data_path, 'test', args.horizon, fix_start=True)

    config = test_set.get_config()
    H, W, C = config['observations']
    A = config['actions'][0]
    T = args.horizon

    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, num_workers=args.cpu_workers)

    # model setup
    if args.model_name == 'cdna':
        model = CDNA(T, H, W, C, A)
    elif args.model_name == 'etd':
        model = ETD(H, W, C, A, T, 5)
    elif args.model_name == 'etds':
        model = ETDS(H, W, C, A, T, 5)
    elif args.model_name == 'etdm':
        model = ETDM(H, W, C, A, T, 5)
    elif args.model_name == 'etdsd':
        model = ETDSD(H, W, C, A, T, 5)
    else:
        raise ValueError('unknown model name: {}'.format(args.model_name))

    model.to(device)

    model_path = os.path.join(args.model_path, '{}_10'.format(args.model_name))

    load_model(model, os.path.join(model_path, '{}_{}.pt'.format(args.model_name, args.load_point)), eval_mode=True)

    # tensorboard
    writer = SummaryWriter()


    gif_path = os.path.join(model_path, 'test_{}'.format(args.horizon))
    if not os.path.exists(gif_path):
        os.makedirs(gif_path)

    losses = []
    videos = []
    inference_times = []

    with torch.no_grad():
        for j, data in enumerate(test_loader):
            observations = data['observations']
            actions = data['actions']

            # B x T ==> T x B
            observations = torch.transpose(observations, 0, 1).to(device)
            actions = torch.transpose(actions, 0, 1).to(device)

            start_time = time.time()
            predicted_observations = model(observations[0], actions)
            inference_times.append((time.time() - start_time) * 1000)

            video = torch.cat([observations[0, 0].unsqueeze(0), predicted_observations[0 : T - 1, 0]]) # tensor[T, C, H, W]
            videos.append(video.unsqueeze(0).detach().cpu())

            if args.save_gif:
                torch_save_gif(os.path.join(gif_path, "{}.gif".format(j)), video.detach().cpu(), fps=10)

            loss = mse_loss(observations, predicted_observations).item() / args.batch_size
            losses.append(loss)

            del loss, observations, actions, predicted_observations # clear the memory

    
    videos = torch.cat(videos, 0)
    writer.add_video('test_video_{}_{}'.format(args.model_name, args.horizon), videos, global_step=0, fps=10)
    writer.close()

    print("-" * 50)
    print("mean loss in test set is {}, std is {}".format(np.mean(losses), np.std(losses)))
    print("mean inference time in test set is {}, std is {}".format(np.mean(inference_times), np.std(inference_times)))
    print("-" * 50)
Example #25
def plot_2d_or_3d_image(
    data: Union[NdarrayTensor, List[NdarrayTensor]],
    step: int,
    writer: SummaryWriter,
    index: int = 0,
    max_channels: int = 1,
    frame_dim: int = -3,
    max_frames: int = 24,
    tag: str = "output",
) -> None:
    """Plot 2D or 3D image on the TensorBoard, 3D image will be converted to GIF image.

    Note:
        Plot 3D or 2D image(with more than 3 channels) as separate images.
        And if writer is from TensorBoardX, data has 3 channels and `max_channels=3`, will plot as RGB video.

    Args:
        data: target data to be plotted as image on the TensorBoard.
            The data is expected to have 'NCHW[D]' dimensions or a list of data with `CHW[D]` dimensions,
            and only plot the first in the batch.
        step: current step to plot in a chart.
        writer: specify TensorBoard or TensorBoardX SummaryWriter to plot the image.
        index: plot which element in the input data batch, default is the first element.
        max_channels: number of channels to plot.
        frame_dim: if plotting 3D image as GIF, specify the dimension used as frames,
            expect input data shape as `NCHWD`, default to `-3` (the first spatial dim)
        max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`.
        tag: tag of the plotted image on TensorBoard.
    """
    data_index = data[index]
    # as the `d` data has no batch dim, reduce the spatial dim index if positive
    frame_dim = frame_dim - 1 if frame_dim > 0 else frame_dim

    d: np.ndarray = (data_index.detach().cpu().numpy()
                     if isinstance(data_index, torch.Tensor) else data_index)

    if d.ndim == 2:
        d = rescale_array(d, 0, 1)  # type: ignore
        dataformats = "HW"
        writer.add_image(f"{tag}_{dataformats}",
                         d,
                         step,
                         dataformats=dataformats)
        return

    if d.ndim == 3:
        if d.shape[0] == 3 and max_channels == 3:  # RGB
            dataformats = "CHW"
            writer.add_image(f"{tag}_{dataformats}",
                             d,
                             step,
                             dataformats=dataformats)
            return
        dataformats = "HW"
        for j, d2 in enumerate(d[:max_channels]):
            d2 = rescale_array(d2, 0, 1)
            writer.add_image(f"{tag}_{dataformats}_{j}",
                             d2,
                             step,
                             dataformats=dataformats)
        return

    if d.ndim >= 4:
        spatial = d.shape[-3:]
        d = d.reshape([-1] + list(spatial))
        if (d.shape[0] == 3 and max_channels == 3 and has_tensorboardx
                and isinstance(writer, SummaryWriterX)):  # RGB
            # move the expected frame dim to the end as `T` dim for video
            d = np.moveaxis(d, frame_dim, -1)
            writer.add_video(tag,
                             d[None],
                             step,
                             fps=max_frames,
                             dataformats="NCHWT")
            return
        # scale data to 0 - 255 for visualization
        max_channels = min(max_channels, d.shape[0])
        d = np.stack([rescale_array(i, 0, 255) for i in d[:max_channels]],
                     axis=0)
        # will plot every channel as a separate GIF image
        add_animated_gif(writer,
                         f"{tag}_HWD",
                         d,
                         max_out=max_channels,
                         frame_dim=frame_dim,
                         global_step=step)
        return
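
A hedged call sketch for the function above (the log dir and volume sizes are assumptions; `rescale_array` and `add_animated_gif` must be importable from the surrounding module):

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/seg_demo')             # hypothetical log dir
volume = torch.rand(2, 1, 64, 64, 32)               # NCHWD: a batch of single-channel 3D volumes
plot_2d_or_3d_image(volume, step=0, writer=writer)  # first item is plotted as an animated GIF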
Example #26
File: logger.py Project: renmengye/curl
class Logger(object):
    def __init__(self, log_dir, use_tb=True, config="rl"):
        self._log_dir = log_dir
        if use_tb:
            tb_dir = os.path.join(log_dir, "tb")
            if os.path.exists(tb_dir):
                shutil.rmtree(tb_dir)
            self._sw = SummaryWriter(tb_dir)
        else:
            self._sw = None
        self._train_mg = MetersGroup(os.path.join(log_dir, "train.log"),
                                     formating=FORMAT_CONFIG[config]["train"])
        self._eval_mg = MetersGroup(os.path.join(log_dir, "eval.log"),
                                    formating=FORMAT_CONFIG[config]["eval"])

    def _try_sw_log(self, key, value, step):
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def _try_sw_log_image(self, key, image, step):
        if self._sw is not None:
            assert image.dim() == 3
            grid = torchvision.utils.make_grid(image.unsqueeze(1))
            self._sw.add_image(key, grid, step)

    def _try_sw_log_video(self, key, frames, step):
        if self._sw is not None:
            frames = torch.from_numpy(np.array(frames))
            frames = frames.unsqueeze(0)
            self._sw.add_video(key, frames, step, fps=30)

    def _try_sw_log_histogram(self, key, histogram, step):
        if self._sw is not None:
            self._sw.add_histogram(key, histogram, step)

    def log(self, key, value, step, n=1):
        assert key.startswith("train") or key.startswith("eval")
        if isinstance(value, torch.Tensor):
            value = value.item()
        self._try_sw_log(key, value / n, step)
        mg = self._train_mg if key.startswith("train") else self._eval_mg
        mg.log(key, value, n)

    def log_param(self, key, param, step):
        self.log_histogram(key + "_w", param.weight.data, step)
        if hasattr(param.weight, "grad") and param.weight.grad is not None:
            self.log_histogram(key + "_w_g", param.weight.grad.data, step)
        if hasattr(param, "bias"):
            self.log_histogram(key + "_b", param.bias.data, step)
            if hasattr(param.bias, "grad") and param.bias.grad is not None:
                self.log_histogram(key + "_b_g", param.bias.grad.data, step)

    def log_image(self, key, image, step):
        assert key.startswith("train") or key.startswith("eval")
        self._try_sw_log_image(key, image, step)

    def log_video(self, key, frames, step):
        assert key.startswith("train") or key.startswith("eval")
        self._try_sw_log_video(key, frames, step)

    def log_histogram(self, key, histogram, step):
        assert key.startswith("train") or key.startswith("eval")
        self._try_sw_log_histogram(key, histogram, step)

    def dump(self, step):
        self._train_mg.dump(step, "train")
        self._eval_mg.dump(step, "eval")
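
A hypothetical usage sketch, assuming MetersGroup and FORMAT_CONFIG from the same project are importable; frame sizes are assumptions:

import numpy as np

logger = Logger('./logs', use_tb=True)
logger.log('train/episode_reward', 12.5, step=1000)
frames = [np.zeros((3, 84, 84), dtype=np.float32) for _ in range(16)]  # T frames of [C, H, W]
logger.log_video('eval/rollout', frames, step=1000)  # stacked into a [1, T, C, H, W] tensor
logger.dump(step=1000)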
Example #27
                # generate videos
                videos = []
                for i in range(params['validation_samples']):
                    pbar.set_description(
                        f'Rendering rollout {i + 1}/{params["validation_samples"]}',
                        refresh=True)
                    initial = next(iter(validation_loader))
                    graph = initial[0]
                    graph.to(device=device)
                    pos = infer_trajectory(network, graph,
                                           params['validation_steps'])
                    video = render_rollout_tensor(pos, title=f'step {step}')
                    videos.append(video)

                videos = torch.cat(videos, dim=0)
                writer.add_video('rollout', videos, step, fps=60)
                pbar.set_description(refresh=True)

        if step % params['model_save_interval'] == 0:
            torch.save(
                {
                    'step': step,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'model_state_dict': network.cpu().state_dict(),
                    'loss': loss
                }, join(args.save_dir, f'{args.run_name}-{step}.pt'))
            network.to(device=device)

        if step >= params['steps']:
            break
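
The checkpoint block above moves the network to CPU before serializing its state dict, then moves it back to the training device; a stand-alone sketch of the same pattern (model, optimizer, and path are assumptions):

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
network = nn.Linear(8, 8).to(device)
optimizer = torch.optim.Adam(network.parameters())
torch.save({'step': 0,
            'optimizer_state_dict': optimizer.state_dict(),
            'model_state_dict': network.cpu().state_dict()}, 'ckpt.pt')  # hypothetical path
network.to(device=device)  # restore the device changed by cpu() above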
Example #28
class Trainer(object):
    def __init__(self,
                 image_sampler,
                 video_sampler,
                 log_interval,
                 train_batches,
                 log_folder,
                 use_cuda=False,
                 use_infogan=True,
                 use_categories=True):

        self.use_categories = use_categories

        self.gan_criterion = nn.BCEWithLogitsLoss()
        self.category_criterion = nn.CrossEntropyLoss()

        self.image_sampler = image_sampler
        self.video_sampler = video_sampler

        self.video_batch_size = self.video_sampler.batch_size
        self.image_batch_size = self.image_sampler.batch_size

        self.log_interval = log_interval
        self.train_batches = train_batches

        self.log_folder = log_folder

        self.use_cuda = use_cuda
        self.use_infogan = use_infogan

        self.image_enumerator = None
        self.video_enumerator = None

        self.writer = SummaryWriter(self.log_folder)

    @staticmethod
    def ones_like(tensor, val=1.):
        return Variable(T.FloatTensor(tensor.size()).fill_(val),
                        requires_grad=False)

    @staticmethod
    def zeros_like(tensor, val=0.):
        return Variable(T.FloatTensor(tensor.size()).fill_(val),
                        requires_grad=False)

    def compute_gan_loss(self, discriminator, sample_true, sample_fake,
                         is_video):
        real_batch = sample_true()

        batch_size = real_batch['images'].size(0)
        fake_batch, generated_categories = sample_fake(batch_size)

        real_labels, real_categorical = discriminator(
            Variable(real_batch['images']))
        fake_labels, fake_categorical = discriminator(fake_batch)

        fake_gt, real_gt = self.get_gt_for_discriminator(batch_size, real=0.)

        l_discriminator = self.gan_criterion(real_labels, real_gt) + \
                          self.gan_criterion(fake_labels, fake_gt)

        # update image discriminator here

        # sample again for videos

        # update video discriminator

        # sample again
        # - videos
        # - images

        # l_videos + l_images -> l
        # l.backward()
        # opt.step()

        #  sample again and compute for generator

        fake_gt = self.get_gt_for_generator(batch_size)
        # to real_gt
        l_generator = self.gan_criterion(fake_labels, fake_gt)

        if is_video:

            # Ask the video discriminator to learn categories from training videos
            categories_gt = Variable(
                torch.squeeze(real_batch['categories'].long()))
            l_discriminator += self.category_criterion(real_categorical,
                                                       categories_gt)

            if self.use_infogan:
                # Ask the generator to generate categories recognizable by the discriminator
                l_generator += self.category_criterion(fake_categorical,
                                                       generated_categories)

        return l_generator, l_discriminator

    def sample_real_image_batch(self):
        if self.image_enumerator is None:
            self.image_enumerator = enumerate(self.image_sampler)

        batch_idx, batch = next(self.image_enumerator)
        b = batch
        if self.use_cuda:
            for k, v in batch.items():
                b[k] = v.cuda()

        if batch_idx == len(self.image_sampler) - 1:
            self.image_enumerator = enumerate(self.image_sampler)

        return b

    def sample_real_video_batch(self):
        if self.video_enumerator is None:
            self.video_enumerator = enumerate(self.video_sampler)

        batch_idx, batch = next(self.video_enumerator)
        b = batch
        if self.use_cuda:
            for k, v in batch.items():
                b[k] = v.cuda()

        if batch_idx == len(self.video_sampler) - 1:
            self.video_enumerator = enumerate(self.video_sampler)

        return b

    def train_discriminator(self, discriminator, sample_true, sample_fake, opt,
                            batch_size, use_categories):
        opt.zero_grad()

        real_batch = sample_true()
        batch = Variable(real_batch['images'], requires_grad=False)

        # util.show_batch(batch.data)

        fake_batch, generated_categories = sample_fake(batch_size)

        real_labels, real_categorical = discriminator(batch)
        fake_labels, fake_categorical = discriminator(fake_batch.detach())

        ones = self.ones_like(real_labels)
        zeros = self.zeros_like(fake_labels)

        l_discriminator = self.gan_criterion(real_labels, ones) + \
                          self.gan_criterion(fake_labels, zeros)

        if use_categories:
            # Ask the video discriminator to learn categories from training videos
            categories_gt = Variable(torch.squeeze(
                real_batch['categories'].long()),
                                     requires_grad=False)
            l_discriminator += self.category_criterion(
                real_categorical.squeeze(), categories_gt)

        l_discriminator.backward()
        opt.step()

        return l_discriminator

    def train_generator(self, image_discriminator, video_discriminator,
                        sample_fake_images, sample_fake_videos, opt):

        opt.zero_grad()

        # train on images

        fake_batch, generated_categories = sample_fake_images(
            self.image_batch_size)
        fake_labels, fake_categorical = image_discriminator(fake_batch)
        all_ones = self.ones_like(fake_labels)

        l_generator = self.gan_criterion(fake_labels, all_ones)

        # train on videos

        fake_batch, generated_categories = sample_fake_videos(
            self.video_batch_size)
        fake_labels, fake_categorical = video_discriminator(fake_batch)
        all_ones = self.ones_like(fake_labels)

        l_generator += self.gan_criterion(fake_labels, all_ones)

        if self.use_infogan:
            # Ask the generator to generate categories recognizable by the discriminator
            l_generator += self.category_criterion(fake_categorical.squeeze(),
                                                   generated_categories)

        l_generator.backward()
        opt.step()

        return l_generator

    def train(self, generator, image_discriminator, video_discriminator):
        if self.use_cuda:
            generator.cuda()
            image_discriminator.cuda()
            video_discriminator.cuda()

        # create optimizers
        opt_generator = optim.Adam(generator.parameters(),
                                   lr=0.0002,
                                   betas=(0.5, 0.999),
                                   weight_decay=0.00001)
        opt_image_discriminator = optim.Adam(image_discriminator.parameters(),
                                             lr=0.0002,
                                             betas=(0.5, 0.999),
                                             weight_decay=0.00001)
        opt_video_discriminator = optim.Adam(video_discriminator.parameters(),
                                             lr=0.0002,
                                             betas=(0.5, 0.999),
                                             weight_decay=0.00001)

        # training loop

        def sample_fake_image_batch(batch_size):
            return generator.sample_images(batch_size)

        def sample_fake_video_batch(batch_size):
            return generator.sample_videos(batch_size)

        def init_logs():
            return {'l_gen': 0, 'l_image_dis': 0, 'l_video_dis': 0}

        batch_num = 0

        logs = init_logs()

        start_time = time.time()

        while True:
            generator.train()
            image_discriminator.train()
            video_discriminator.train()

            opt_generator.zero_grad()

            opt_video_discriminator.zero_grad()

            # train image discriminator
            l_image_dis = self.train_discriminator(
                image_discriminator,
                self.sample_real_image_batch,
                sample_fake_image_batch,
                opt_image_discriminator,
                self.image_batch_size,
                use_categories=False)

            # train video discriminator
            l_video_dis = self.train_discriminator(
                video_discriminator,
                self.sample_real_video_batch,
                sample_fake_video_batch,
                opt_video_discriminator,
                self.video_batch_size,
                use_categories=self.use_categories)

            # train generator
            l_gen = self.train_generator(image_discriminator,
                                         video_discriminator,
                                         sample_fake_image_batch,
                                         sample_fake_video_batch,
                                         opt_generator)

            logs['l_gen'] += l_gen.data

            logs['l_image_dis'] += l_image_dis.data
            logs['l_video_dis'] += l_video_dis.data

            batch_num += 1

            if batch_num % self.log_interval == 0:

                log_string = "Batch %d" % batch_num
                for k, v in logs.items():
                    log_string += " [%s] %5.3f" % (k, v / self.log_interval)

                log_string += ". Took %5.2f" % (time.time() - start_time)

                print(log_string)

                # log loss
                for tag, value in list(logs.items()):
                    self.writer.add_scalar(tag, value / self.log_interval,
                                           batch_num)

                logs = init_logs()
                start_time = time.time()

                generator.eval()

                images, _ = sample_fake_image_batch(self.image_batch_size)
                # log images
                self.writer.add_images('images-objs', images[:, 0:3, :, :],
                                       batch_num)
                self.writer.add_images('images-background',
                                       images[:, 3:6, :, :], batch_num)

                videos, _ = sample_fake_video_batch(self.video_batch_size)
                # log videos
                vid_obj = videos[:, 0:3, :, :, :].permute(0, 2, 1, 3, 4)
                vid_background = videos[:, 3:6, :, :, :].permute(0, 2, 1, 3, 4)
                self.writer.add_video("video_obj", vid_obj, batch_num, fps=35)
                self.writer.add_video("video_background",
                                      vid_background,
                                      batch_num,
                                      fps=35)

                torch.save(
                    generator,
                    os.path.join(self.log_folder,
                                 'generator_%05d.pytorch' % batch_num))

            if batch_num >= self.train_batches:
                torch.save(
                    generator,
                    os.path.join(self.log_folder,
                                 'generator_%05d.pytorch' % batch_num))
                break
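
The generator in this trainer emits videos as [N, C, T, H, W] with object and background channels stacked, so the permute above swaps the channel and time axes before logging; a shape sketch (all sizes are assumptions):

import torch

videos = torch.rand(4, 6, 16, 64, 64)             # [N, C=6 (obj + background), T, H, W]
vid_obj = videos[:, 0:3].permute(0, 2, 1, 3, 4)   # -> [N, T, C, H, W], as add_video expects
assert vid_obj.shape == (4, 16, 3, 64, 64)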
Example #29
class Logger:
    def __init__(self, log_dir, n_logged_samples=10, summary_writer=None):
        self._log_dir = log_dir
        print('########################')
        print('logging outputs to ', log_dir)
        print('########################')
        self._n_logged_samples = n_logged_samples
        self._summ_writer = SummaryWriter(log_dir, flush_secs=1, max_queue=1)

    def log_scalar(self, scalar, name, step_):
        self._summ_writer.add_scalar('{}'.format(name), scalar, step_)

    def log_scalars(self, scalar_dict, group_name, step, phase):
        """Will log all scalars in the same plot."""
        self._summ_writer.add_scalars('{}_{}'.format(group_name, phase),
                                      scalar_dict, step)

    def log_image(self, image, name, step):
        assert (len(image.shape) == 3)  # [C, H, W]
        self._summ_writer.add_image('{}'.format(name), image, step)

    def log_video(self, video_frames, name, step, fps=10):
        assert len(
            video_frames.shape
        ) == 5, "Need [N, T, C, H, W] input tensor for video logging!"
        self._summ_writer.add_video('{}'.format(name),
                                    video_frames,
                                    step,
                                    fps=fps)

    def log_paths_as_videos(self,
                            paths,
                            step,
                            max_videos_to_save=2,
                            fps=10,
                            video_title='video'):

        # reshape the rollouts
        videos = [np.transpose(p['image_obs'], [0, 3, 1, 2]) for p in paths]

        # clamp the number of saved videos, then find the max rollout length
        max_videos_to_save = np.min([max_videos_to_save, len(videos)])
        max_length = videos[0].shape[0]
        for i in range(max_videos_to_save):
            if videos[i].shape[0] > max_length:
                max_length = videos[i].shape[0]

        # pad rollouts to all be same length
        for i in range(max_videos_to_save):
            if videos[i].shape[0] < max_length:
                padding = np.tile([videos[i][-1]],
                                  (max_length - videos[i].shape[0], 1, 1, 1))
                videos[i] = np.concatenate([videos[i], padding], 0)

        # log videos to tensorboard event file
        print("Logging videos")
        videos = np.stack(videos[:max_videos_to_save], 0)
        self.log_video(videos, video_title, step, fps=fps)

    def log_figures(self, figure, name, step, phase):
        """figure: matplotlib.pyplot figure handle"""
        assert figure.shape[
            0] > 0, "Figure logging requires input shape [batch x figures]!"
        self._summ_writer.add_figure('{}_{}'.format(name, phase), figure, step)

    def log_figure(self, figure, name, step, phase):
        """figure: matplotlib.pyplot figure handle"""
        self._summ_writer.add_figure('{}_{}'.format(name, phase), figure, step)

    def log_graph(self, array, name, step, phase):
        """figure: matplotlib.pyplot figure handle"""
        im = plot_graph(array)
        self._summ_writer.add_image('{}_{}'.format(name, phase), im, step)

    def dump_scalars(self, log_path=None):
        log_path = os.path.join(
            self._log_dir,
            "scalar_data.json") if log_path is None else log_path
        self._summ_writer.export_scalars_to_json(log_path)

    def flush(self):
        self._summ_writer.flush()
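
The padding step in log_paths_as_videos repeats the final frame of shorter rollouts so every video has the same length; a stand-alone sketch of that trick (sizes are assumptions):

import numpy as np

video = np.zeros((7, 3, 64, 64))   # a rollout of length T=7, frames as [C, H, W]
max_length = 10
padding = np.tile([video[-1]], (max_length - video.shape[0], 1, 1, 1))  # repeat the last frame
padded = np.concatenate([video, padding], 0)  # [10, C, H, W]
assert padded.shape[0] == max_length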
Example #30
class Trainer:
    def __init__(self,
                 save_name='train',
                 description="Default model trainer",
                 drop_last=False,
                 allow_val_grad=False):
        now = datetime.datetime.now()
        parser = argparse.ArgumentParser(description=description)
        parser.add_argument('experiment_file',
                            type=str,
                            help='path to YAML experiment config file')
        parser.add_argument(
            '--save_path',
            type=str,
            default='',
            help=
            'path to place model save file in during training (overwrites config)'
        )
        parser.add_argument('--device',
                            type=int,
                            default=None,
                            nargs='+',
                            help='target device (uses all if not specified)')
        args = parser.parse_args()
        self._config = parse_basic_config(args.experiment_file)
        save_config = copy.deepcopy(self._config)
        if args.save_path:
            self._config['save_path'] = args.save_path

        # initialize device
        def_device = 0 if args.device is None else args.device[0]
        self._device = torch.device("cuda:{}".format(def_device))
        self._device_list = args.device
        self._allow_val_grad = allow_val_grad

        # parse dataset class and create train/val loaders
        dataset_class = get_dataset(self._config['dataset'].pop('type'))
        dataset = dataset_class(**self._config['dataset'], mode='train')
        val_dataset = dataset_class(**self._config['dataset'], mode='val')
        self._train_loader = DataLoader(
            dataset,
            batch_size=self._config['batch_size'],
            shuffle=True,
            num_workers=self._config.get('loader_workers', cpu_count()),
            drop_last=drop_last,
            worker_init_fn=lambda w: np.random.seed(
                np.random.randint(2**29) + w))
        self._val_loader = DataLoader(
            val_dataset,
            batch_size=self._config['batch_size'],
            shuffle=True,
            num_workers=min(1, self._config.get('loader_workers',
                                                cpu_count())),
            drop_last=True,
            worker_init_fn=lambda w: np.random.seed(
                np.random.randint(2**29) + w))

        # set up file saving
        save_dir = os.path.join(
            self._config.get('save_path', './'),
            '{}_ckpt-{}-{}_{}-{}-{}'.format(save_name, now.hour, now.minute,
                                            now.day, now.month, now.year))
        save_dir = os.path.expanduser(save_dir)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        with open(os.path.join(save_dir, 'config.yaml'), 'w') as f:
            yaml.dump(save_config, f, default_flow_style=False)
        self._writer = SummaryWriter(log_dir=os.path.join(save_dir, 'log'))
        self._save_fname = os.path.join(save_dir, 'model_save')
        self._step = None

    @property
    def config(self):
        return copy.deepcopy(self._config)

    def train(self,
              model,
              train_fn,
              weights_fn=None,
              val_fn=None,
              save_fn=None,
              optim_weights=None):
        # wrap model in DataParallel if needed and transfer to correct device
        if self.device_count > 1:
            model = nn.DataParallel(model, device_ids=self.device_list)
        model = model.to(self._device)

        # initialize optimizer and lr scheduler
        optim_weights = (optim_weights if optim_weights is not None
                         else model.parameters())
        optimizer, scheduler = self._build_optimizer_and_scheduler(
            optim_weights)

        # initialize constants:
        epochs = self._config.get('epochs', 1)
        vlm_alpha = self._config.get('vlm_alpha', 0.6)
        log_freq = self._config.get('log_freq', 20)
        self._img_log_freq = img_log_freq = self._config.get(
            'img_log_freq', 500)
        assert img_log_freq % log_freq == 0, "log_freq must divide img_log_freq!"
        save_freq = self._config.get('save_freq', 5000)

        if val_fn is None:
            val_fn = train_fn

        self._step = 0
        train_stats = {'loss': 0}
        val_iter = iter(self._val_loader)
        vl_running_mean = None
        for e in range(epochs):
            for inputs in self._train_loader:
                self._zero_grad(optimizer)
                loss_i, stats_i = train_fn(model, self._device, *inputs)
                self._step_optim(loss_i, self._step, optimizer)

                # calculate iter stats
                mod_step = self._step % log_freq
                train_stats['loss'] = (self._loss_to_scalar(loss_i) +
                                       mod_step * train_stats['loss']) / (
                                           mod_step + 1)
                for k, v in stats_i.items():
                    if isinstance(v, torch.Tensor):
                        assert len(v.shape) >= 4, "assumes 4dim BCHW image tensor!"
                        # image/video tensors are stored as-is, not running-averaged
                        train_stats[k] = v
                        continue
                    if k not in train_stats:
                        train_stats[k] = 0
                    train_stats[k] = (v + mod_step * train_stats[k]) / (mod_step + 1)

                if mod_step == 0:
                    try:
                        val_inputs = next(val_iter)
                    except StopIteration:
                        val_iter = iter(self._val_loader)
                        val_inputs = next(val_iter)

                    if self._allow_val_grad:
                        model = model.eval()
                        val_loss, val_stats = val_fn(model, self._device,
                                                     *val_inputs)
                        model = model.train()
                        val_loss = self._loss_to_scalar(val_loss)
                    else:
                        with torch.no_grad():
                            model = model.eval()
                            val_loss, val_stats = val_fn(
                                model, self._device, *val_inputs)
                            model = model.train()
                            val_loss = self._loss_to_scalar(val_loss)

                    # update running mean stat
                    if vl_running_mean is None:
                        vl_running_mean = val_loss
                    vl_running_mean = val_loss * vlm_alpha + vl_running_mean * (
                        1 - vlm_alpha)

                    self._writer.add_scalar('loss/val', val_loss, self._step)
                    for stats_dict, mode in zip([train_stats, val_stats],
                                                ['train', 'val']):
                        for k, v in stats_dict.items():
                            if isinstance(v, torch.Tensor
                                          ) and self.step % img_log_freq == 0:
                                if len(v.shape) == 5:
                                    self._writer.add_video(
                                        '{}/{}'.format(k, mode), v.cpu(),
                                        self._step)
                                else:
                                    v_grid = torchvision.utils.make_grid(
                                        v.cpu(), padding=5)
                                    self._writer.add_image(
                                        '{}/{}'.format(k, mode), v_grid,
                                        self._step)
                            elif not isinstance(v, torch.Tensor):
                                self._writer.add_scalar(
                                    '{}/{}'.format(k, mode), v, self._step)

                    # add learning rate parameter to log
                    lrs = np.mean([p['lr'] for p in optimizer.param_groups])
                    self._writer.add_scalar('lr', lrs, self._step)

                    # flush to disk and print
                    self._writer.file_writer.flush()
                    print(
                        'epoch {3}/{4}, step {0}: loss={1:.4f} \t val loss={2:.4f}'
                        .format(self._step, train_stats['loss'],
                                vl_running_mean, e, epochs))
                else:
                    print('step {0}: loss={1:.4f}'.format(
                        self._step, train_stats['loss']),
                          end='\r')
                self._step += 1

                if self._step % save_freq == 0:
                    if save_fn is not None:
                        save_fn(self._save_fname, self._step)
                    else:
                        save_module = model
                        if weights_fn is not None:
                            save_module = weights_fn()
                        elif isinstance(model, nn.DataParallel):
                            save_module = model.module
                        torch.save(
                            save_module,
                            self._save_fname + '-{}.pt'.format(self._step))
                    if self._config.get('save_optim', False):
                        torch.save(
                            optimizer.state_dict(), self._save_fname +
                            '-optim-{}.pt'.format(self._step))
            scheduler.step(val_loss=vl_running_mean)

    @property
    def device_count(self):
        if self._device_list is None:
            return torch.cuda.device_count()
        return len(self._device_list)

    @property
    def device_list(self):
        if self._device_list is None:
            return [i for i in range(torch.cuda.device_count())]
        return copy.deepcopy(self._device_list)

    @property
    def device(self):
        return copy.deepcopy(self._device)

    def _build_optimizer_and_scheduler(self, optim_weights):
        optimizer = torch.optim.Adam(optim_weights,
                                     self._config['lr'],
                                     weight_decay=self._config.get(
                                         'weight_decay', 0))
        return optimizer, build_scheduler(optimizer,
                                          self._config.get('lr_schedule', {}))

    def _step_optim(self, loss, step, optimizer):
        loss.backward()
        optimizer.step()

    def _zero_grad(self, optimizer):
        optimizer.zero_grad()

    def _loss_to_scalar(self, loss):
        return loss.item()

    @property
    def step(self):
        if self._step is None:
            raise Exception("Optimization has not begun!")
        return self._step

    @property
    def is_img_log_step(self):
        return self._step % self._img_log_freq == 0
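
A hypothetical train_fn compatible with Trainer.train above, assuming an image-to-image model: it must return (loss, stats), and any tensor entries in stats must be at least 4-dim image batches, which the trainer routes to add_image (4-dim) or add_video (5-dim):

import torch
import torch.nn.functional as F

def train_fn(model, device, batch, targets):
    batch, targets = batch.to(device), targets.to(device)
    preds = model(batch)
    loss = F.mse_loss(preds, targets)
    stats = {'mse': loss.item(), 'preds': preds.detach()[:4]}  # scalar + [B, C, H, W] sample
    return loss, stats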