class TensorboardSummary(object):
    """Thin wrapper around a tensorboard SummaryWriter for scalar, image
    and video logging during multi-scale video training."""

    def __init__(self, directory):
        # All event files are written directly under `directory`.
        self.directory = directory
        self.writer = SummaryWriter(log_dir=os.path.join(self.directory))

    def add_scalar(self, *args):
        """Pass-through to SummaryWriter.add_scalar(tag, value, step)."""
        self.writer.add_scalar(*args)

    def visualize_video(self, opt, global_step, video, name):
        """Log a video batch both as an unfolded frame grid and as a clip.

        `video` is assumed to be (B, C, T, H, W) — TODO confirm against callers.
        """
        video_transpose = video.permute(0, 2, 1, 3, 4)  # BxTxCxHxW
        video_reshaped = video_transpose.flatten(0, 1)  # (B*T)xCxHxW

        # image_range = opt.td #+ opt.num_targets
        image_range = video.shape[2]  # frames per clip (T)

        # Grid of the first 3 clips, one row per clip.
        grid_image = make_grid(
            video_reshaped[:3 * image_range, :, :, :].clone().cpu().data,
            image_range,
            normalize=True)
        self.writer.add_image(
            'Video/Scale {}/{}_unfold'.format(opt.scale_idx, name),
            grid_image, global_step)

        # norm_range normalizes in place before add_video — presumably an
        # in-place min/max rescale; verify against its definition.
        norm_range(video_transpose)
        self.writer.add_video('Video/Scale {}/{}'.format(opt.scale_idx, name),
                              video_transpose[:3], global_step)

    def visualize_image(self, opt, global_step, images, name):
        """Log the first 3 images of a batch as a single grid image.

        Fix: the parameter was spelled `ןimages` (a stray Hebrew character
        prefixed the identifier); renamed to `images` — callers pass it
        positionally.
        """
        grid_image = make_grid(images[:3, :, :, :].clone().cpu().data,
                               3,
                               normalize=True)
        self.writer.add_image('Image/Scale {}/{}'.format(opt.scale_idx, name),
                              grid_image, global_step)
class Logger:
    """Minimal facade over a tensorboard SummaryWriter."""

    def __init__(self, log_dir, summary_writer=None):
        # `summary_writer` is accepted for interface compatibility, but a
        # fresh SummaryWriter is always created.
        self.writer = SummaryWriter(log_dir, flush_secs=1, max_queue=20)
        self.step = 0

    def add_graph(self, model=None, input: tuple = None):
        """Write the model graph, then force a flush."""
        self.writer.add_graph(model, input)
        self.flush()

    def log_scalar(self, scalar, name, step_):
        """Record one scalar under tag `name` at step `step_`."""
        tag = '{}'.format(name)
        self.writer.add_scalar(tag, scalar, step_)

    def log_scalars(self, scalar_dict, step, phase='Train_'):
        """Will log all scalars in the same plot."""
        self.writer.add_scalars(phase, scalar_dict, step)

    def log_video(self, video_frames, name, step, fps=10):
        """Record a video; expects a 5-D [N, T, C, H, W] tensor."""
        assert len(video_frames.shape) == 5, \
            "Need [N, T, C, H, W] input tensor for video logging!"
        tag = '{}'.format(name)
        self.writer.add_video(tag, video_frames, step, fps=fps)

    def flush(self):
        """Push pending events to disk."""
        self.writer.flush()

    def close(self):
        """Close the underlying writer."""
        self.writer.close()
class Tensorboard:
    """Adapter exposing typed write_* helpers over one SummaryWriter,
    configured from a ConfigFactory instance."""

    def __init__(self, config: ConfigFactory) -> None:
        self.config = config
        # Log location and fps both come from the experiment configuration.
        self.writer = SummaryWriter(self.config.tensorboard_path())

    def write_scalar(self, title, value, iteration) -> None:
        """Log a single scalar value."""
        self.writer.add_scalar(title, value, global_step=iteration)

    def write_video(self, title, value, iteration) -> None:
        """Log a video clip at the configured frame rate."""
        self.writer.add_video(title, value, global_step=iteration,
                              fps=self.config.fps)

    def write_image(self, title, value, iteration) -> None:
        """Log a single (C, H, W) image."""
        self.writer.add_image(title, value, global_step=iteration,
                              dataformats='CHW')

    def write_histogram(self, title, value, iteration) -> None:
        """Log a histogram of `value`."""
        self.writer.add_histogram(title, value, global_step=iteration)

    def write_embedding(self, all_embeddings, metadata, images) -> None:
        """Log an embedding matrix with per-point labels and thumbnails."""
        self.writer.add_embedding(all_embeddings, metadata=metadata,
                                  label_img=images)

    def write_embedding_no_labels(self, all_embeddings, images) -> None:
        """Log an embedding matrix with thumbnails but no labels."""
        self.writer.add_embedding(all_embeddings, label_img=images)
class TensorboardWriter:
    def __init__(self, log_dir: str, *args: Any, **kwargs: Any):
        r"""A Wrapper for tensorboard SummaryWriter. It creates a dummy
        writer when log_dir is empty string or None. It also has
        functionality that generates tb video directly from numpy images.

        Args:
            log_dir: Save directory location. Will not write to disk if
                log_dir is an empty string.
            *args: Additional positional args for SummaryWriter
            **kwargs: Additional keyword args for SummaryWriter
        """
        # Truthiness covers both the None and "" dummy-writer cases.
        self.writer = SummaryWriter(log_dir, *args, **kwargs) if log_dir else None

    def __getattr__(self, item):
        # Only reached for names not found normally: delegate to the real
        # writer, or hand back a do-nothing callable in dummy mode so any
        # writer method can be called safely.
        if not self.writer:
            return lambda *args, **kwargs: None
        return self.writer.__getattribute__(item)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.writer:
            self.writer.close()

    def add_video_from_np_images(self, video_name: str, step_idx: int,
                                 images: np.ndarray, fps: int = 10) -> None:
        r"""Write video into tensorboard from images frames.

        Args:
            video_name: name of video string.
            step_idx: int of checkpoint index to be displayed.
            images: list of n frames. Each frame is a np.ndarray of shape.
            fps: frame per second for output video.

        Returns:
            None.
        """
        if not self.writer:
            return
        # initial shape of np.ndarray list: N * (H, W, 3)
        stacked = torch.cat(
            tuple(torch.from_numpy(frame).unsqueeze(0) for frame in images))
        # final shape of video tensor: (1, n, 3, H, W)
        clip = stacked.permute(0, 3, 1, 2).unsqueeze(0)
        self.writer.add_video(video_name, clip, fps=fps, global_step=step_idx)
class Logger(object):
    """Dispatches train/eval metrics to tensorboard (optional) and to
    MetersGroup-backed text log files."""

    def __init__(self, log_dir, use_tb=False, config='rl'):
        self._log_dir = log_dir
        if use_tb:
            tb_dir = os.path.join(log_dir, 'tb')
            # Start each run with a clean tensorboard directory.
            if os.path.exists(tb_dir):
                shutil.rmtree(tb_dir)
            self._sw = SummaryWriter(tb_dir)
        else:
            self._sw = None
        self._train_mg = MetersGroup(os.path.join(log_dir, 'train.log'),
                                     formating=FORMAT_CONFIG[config]['train'])
        self._eval_mg = MetersGroup(os.path.join(log_dir, 'eval.log'),
                                    formating=FORMAT_CONFIG[config]['eval'])

    def _try_sw_log(self, key, value, step):
        # No-op when tensorboard logging is disabled.
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def _try_sw_log_video(self, key, frames, step):
        if self._sw is not None:
            frames = torch.from_numpy(np.array(frames))
            frames = frames.unsqueeze(0)  # add a batch dim for add_video
            self._sw.add_video(key, frames, step, fps=30)

    def _try_sw_log_histogram(self, key, histogram, step):
        if self._sw is not None:
            self._sw.add_histogram(key, histogram, step)

    def log(self, key, value, step, n=1):
        """Log a scalar; `n` is the sample count used for averaging."""
        assert key.startswith('train') or key.startswith('eval')
        # isinstance also accepts torch.nn.Parameter (Tensor subclass),
        # which `type(value) == torch.Tensor` missed.
        if isinstance(value, torch.Tensor):
            value = value.item()
        self._try_sw_log(key, value / n, step)
        mg = self._train_mg if key.startswith('train') else self._eval_mg
        mg.log(key, value, n)

    def log_param(self, key, param, step):
        """Log weight/bias values and gradients of a module."""
        self.log_histogram(key + '_w', param.weight.data, step)
        if hasattr(param.weight, 'grad') and param.weight.grad is not None:
            self.log_histogram(key + '_w_g', param.weight.grad.data, step)
        # Fix: modules routinely define `bias = None` (e.g. Linear with
        # bias=False); the original only checked hasattr and then crashed
        # dereferencing param.bias.data.
        if getattr(param, 'bias', None) is not None:
            self.log_histogram(key + '_b', param.bias.data, step)
            if hasattr(param.bias, 'grad') and param.bias.grad is not None:
                self.log_histogram(key + '_b_g', param.bias.grad.data, step)

    def log_video(self, key, frames, step):
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_video(key, frames, step)

    def log_histogram(self, key, histogram, step):
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_histogram(key, histogram, step)

    def dump(self, step):
        """Flush both meter groups to their log files."""
        self._train_mg.dump(step, 'train')
        self._eval_mg.dump(step, 'eval')
class TensorBoardOutputFormat(KVWriter):
    def __init__(self, folder: str):
        """
        Dumps key/value pairs into TensorBoard's numeric format.

        :param folder: the folder to write the log to
        """
        assert SummaryWriter is not None, ("tensorboard is not installed, you can use "
                                           "pip install tensorboard to do so")
        self.writer = SummaryWriter(log_dir=folder)

    def write(self,
              key_values: Dict[str, Any],
              key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
              step: int = 0) -> None:
        """Dispatch each value to the add_* call matching its type."""
        pairs = zip(sorted(key_values.items()), sorted(key_excluded.items()))
        for (key, value), (_, excluded) in pairs:
            # Skip entries explicitly excluded from tensorboard output.
            if excluded is not None and "tensorboard" in excluded:
                continue

            if isinstance(value, np.ScalarType):
                if isinstance(value, str):
                    # str is considered a np.ScalarType
                    self.writer.add_text(key, value, step)
                else:
                    self.writer.add_scalar(key, value, step)

            if isinstance(value, th.Tensor):
                self.writer.add_histogram(key, value, step)

            if isinstance(value, Video):
                self.writer.add_video(key, value.frames, step, value.fps)

            if isinstance(value, Figure):
                self.writer.add_figure(key, value.figure, step,
                                       close=value.close)

            if isinstance(value, Image):
                self.writer.add_image(key, value.image, step,
                                      dataformats=value.dataformats)

        # Flush the output to the file
        self.writer.flush()

    def close(self) -> None:
        """
        closes the file
        """
        if self.writer:
            self.writer.close()
            self.writer = None
def main(args):
    """Generate one batch of class-conditional videos with a pretrained
    generator and log/save each class's output."""
    # write into tensorboard
    log_path = os.path.join('demos', args.dataset + '/log')
    vid_path = os.path.join('demos', args.dataset + '/vids')
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(vid_path, exist_ok=True)
    writer = SummaryWriter(log_path)

    device = torch.device("cuda:0")
    G = Generator(args.dim_z, args.dim_a, args.nclasses, args.ch).to(device)
    G = nn.DataParallel(G)
    G.load_state_dict(torch.load(args.model_path))

    # Same normalization the generator was trained with (range [-1, 1]).
    transform = torchvision.transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    dataset = MUG_test(args.data_path, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=args.batch_size,
                                             num_workers=args.num_workers,
                                             shuffle=False,
                                             pin_memory=True)

    with torch.no_grad():
        G.eval()
        img = next(iter(dataloader))
        bs = img.size(0)
        nclasses = args.nclasses
        z = torch.randn(bs, args.dim_z).to(device)
        for i in range(nclasses):
            # One-hot class vector for class i.
            y = torch.zeros(bs, nclasses).to(device)
            y[:, i] = 1.0

            vid_gen = G(img, z, y)
            vid_gen = vid_gen.transpose(2, 1)
            # Min-max rescale to [0, 1] over the whole batch.
            vid_gen = ((vid_gen - vid_gen.min()) /
                       (vid_gen.max() - vid_gen.min())).data

            writer.add_video(tag='vid_cat_%d' % i, vid_tensor=vid_gen)
            writer.flush()

            # save videos
            print('==> saving videos')
            save_videos(vid_path, vid_gen, bs, i)
class TensorLogger(object):
    """Convenience wrapper exposing one *_summary method per tensorboard
    record type (scalar, image, figure, video, audio, text, histogram)."""

    # creating file in given logdir ... defaults to ./runs/
    def __init__(self, _logdir='./runs/'):
        os.makedirs(_logdir, exist_ok=True)
        self.writer = SummaryWriter(log_dir=_logdir)

    # adding scalar value to tb file
    def scalar_summary(self, _tag, _value, _step):
        self.writer.add_scalar(_tag, _value, _step)

    # adding image value to tb file
    def image_summary(self, _tag, _image, _step, _format='CHW'):
        """
        default dataformat for image tensor is (3, H, W)
        can be changed to : (1, H, W) - dataformat = CHW
                          : (H, W, 3) - dataformat HWC
                          : (H, W)    - datformat HW
        """
        # Fix: the call was commented out, making this method a silent no-op,
        # and the commented call misspelled the kwarg — SummaryWriter.add_image
        # expects `dataformats`, not `dataformat`.
        self.writer.add_image(_tag, _image, _step, dataformats=_format)

    # adding matplotlib figure to tb file
    def figure_summary(self, _tag, _figure, _step):
        self.writer.add_figure(_tag, _figure, _step)

    # adding video to tb file
    def video_summary(self, _tag, _video, _step, _fps=4):
        """
        default torch fps is 4, can be changed
        also, video tensor should be of format (N, T, C, H, W)
        values should be between [0,255] for unit8 and [0,1] for float32
        """
        # default value of video fps is 4 - can be changed
        self.writer.add_video(_tag, _video, _step, _fps)

    # adding audio to tb file
    def audio_summary(self, _tag, _sound, _step, _sampleRate=44100):
        """
        default torch sample rate is 44100, can be changed
        also, sound tensor should be of format (1, L)
        values should lie between [-1,1]
        """
        self.writer.add_audio(_tag, _sound, _step, sample_rate=_sampleRate)

    # adding text to tb file
    def text_summary(self, _tag, _textString, _step):
        self.writer.add_text(_tag, _textString, _step)

    # adding histograms to tb file
    def histogram_summary(self, _tag, _histogram, _step, _bins='tensorflow'):
        # Fix: was `add_histrogram` (typo) — raised AttributeError at call time.
        self.writer.add_histogram(_tag, _histogram, _step, bins=_bins)
class VisualizerTensorboard:
    """Dispatches registered values to the matching SummaryWriter add_*
    call based on each key's declared dtype."""

    def __init__(self, opts):
        self.dtype = {}
        self.iteration = 1
        self.writer = SummaryWriter(opts.logs_dir)

    def register(self, modules):
        """Remember each key's dtype so update() knows how to log it."""
        # here modules are assumed to be a dictionary
        for key in modules:
            self.dtype[key] = modules[key]['dtype']

    def update(self, modules):
        """Log every (key, value) in `modules` at the current iteration."""
        # Fix: `for key, value in modules` iterated the dict's bare keys and
        # failed on tuple unpacking; modules is a dict, so use .items().
        for key, value in modules.items():
            dtype = self.dtype[key]
            if dtype == 'scalar':
                self.writer.add_scalar(key, value, self.iteration)
            elif dtype == 'scalars':
                self.writer.add_scalars(key, value, self.iteration)
            elif dtype == 'histogram':
                self.writer.add_histogram(key, value, self.iteration)
            elif dtype == 'image':
                self.writer.add_image(key, value, self.iteration)
            elif dtype == 'images':
                self.writer.add_images(key, value, self.iteration)
            elif dtype == 'figure':
                self.writer.add_figure(key, value, self.iteration)
            elif dtype == 'video':
                self.writer.add_video(key, value, self.iteration)
            elif dtype == 'audio':
                self.writer.add_audio(key, value, self.iteration)
            elif dtype == 'text':
                self.writer.add_text(key, value, self.iteration)
            elif dtype == 'embedding':
                # Fix: add_embedding's first positional is the matrix, not
                # the tag.
                self.writer.add_embedding(value,
                                          global_step=self.iteration,
                                          tag=key)
            elif dtype == 'pr_curve':
                # Fix: SummaryWriter exposes add_pr_curve, not pr_curve.
                self.writer.add_pr_curve(key, value['labels'],
                                         value['predictions'],
                                         self.iteration)
            elif dtype == 'mesh':
                # Fix: was add_audio (copy-paste); meshes go to add_mesh.
                self.writer.add_mesh(key, value, global_step=self.iteration)
            elif dtype == 'hparams':
                # Fix: add_hparams takes (hparam_dict, metric_dict); neither
                # a tag nor an iteration is part of its signature.
                self.writer.add_hparams(value['hparam_dict'],
                                        value['metric_dict'])
            else:
                raise Exception(
                    'Data type not supported, please update the visualizer plugin and rerun !!'
                )
        self.iteration = self.iteration + 1
def main():
    """Sample appearance/motion latent codes with a pretrained generator,
    then log the generated videos to tensorboard and save them to disk."""
    args = cfg.parse_args()

    # write into tensorboard
    log_path = os.path.join(args.demo_path, args.demo_name + '/log')
    vid_path = os.path.join(args.demo_path, args.demo_name + '/vids')
    # Fix: the original guard `if not exists(log) and not exists(vid)`
    # skipped creating BOTH directories whenever either one already
    # existed, so a stale log dir left the vids dir missing.
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(vid_path, exist_ok=True)

    writer = SummaryWriter(log_path)

    device = torch.device("cuda:0")
    G = Generator().to(device)
    G = nn.DataParallel(G)
    G.load_state_dict(torch.load(args.model_path))

    with torch.no_grad():
        G.eval()
        # Appearance (za) and motion (zm) latent codes.
        za = torch.randn(args.n_za_test, args.d_za, 1, 1, 1).to(device)
        zm = torch.randn(args.n_zm_test, args.d_zm, 1, 1, 1).to(device)
        n_za = za.size(0)
        n_zm = zm.size(0)
        # Cartesian product: pair every appearance with every motion code.
        za = za.unsqueeze(1).repeat(1, n_zm, 1, 1, 1, 1).contiguous().view(
            n_za * n_zm, -1, 1, 1, 1)
        zm = zm.repeat(n_za, 1, 1, 1, 1)
        vid_fake = G(za, zm)

        vid_fake = vid_fake.transpose(2, 1)  # bs x 16 x 3 x 64 x 64
        # Min-max rescale to [0, 1] for tensorboard display.
        vid_fake = ((vid_fake - vid_fake.min()) /
                    (vid_fake.max() - vid_fake.min())).data

        writer.add_video(tag='generated_videos', global_step=1,
                         vid_tensor=vid_fake)
        writer.flush()

        # save into videos
        print('==> saving videos...')
        save_videos(vid_path, vid_fake, n_za, n_zm)

    return
class TensorboardSummary(object):
    """Metric/image/video logger that writes to tensorboard and, when a
    neptune experiment is supplied, mirrors scalars/images to neptune."""

    def __init__(self, directory, neptune_exp=None):
        self.directory = directory
        self.writer = SummaryWriter(log_dir=os.path.join(self.directory))
        self.neptune_exp = neptune_exp
        # Converts grid tensors to PIL images for neptune uploads.
        self.to_image = transforms.ToPILImage()

    def add_scalar(self, log_name, value, index):
        """Log a scalar to neptune when available, else to tensorboard."""
        if self.neptune_exp:
            self.neptune_exp.log_metric(log_name, index, value)
        else:
            self.writer.add_scalar(log_name, value, index)

    def visualize_video(self, opt, global_step, video, name):
        """Log a video batch both as an unfolded frame grid and as a clip.

        `video` is assumed to be (B, C, T, H, W) — TODO confirm against callers.
        """
        video_transpose = video.permute(0, 2, 1, 3, 4)  # BxTxCxHxW
        video_reshaped = video_transpose.flatten(0, 1)  # (B*T)xCxHxW

        # image_range = opt.td #+ opt.num_targets
        image_range = video.shape[2]  # frames per clip (T)

        grid_image = make_grid(
            video_reshaped[:3 * image_range, :, :, :].clone().cpu().data,
            image_range,
            normalize=True)
        self.writer.add_image(
            'Video/Scale {}/{}_unfold'.format(opt.scale_idx, name),
            grid_image, global_step)

        # norm_range normalizes in place before add_video.
        norm_range(video_transpose)
        self.writer.add_video('Video/Scale {}/{}'.format(opt.scale_idx, name),
                              video_transpose[:3], global_step)

    def visualize_image(self, opt, global_step, images, name):
        """Log the first 3 images as a grid (to neptune when available).

        Fix: the parameter was spelled `ןimages` (a stray Hebrew character
        prefixed the identifier); renamed to `images` — callers pass it
        positionally.
        """
        grid_image = make_grid(images[:3, :, :, :].clone().cpu().data,
                               3,
                               normalize=True)
        img_name = 'Image/Scale {}/{}'.format(opt.scale_idx, name)
        if self.neptune_exp:
            self.neptune_exp.log_image(img_name,
                                       global_step,
                                       y=self.to_image(grid_image))
        else:
            self.writer.add_image(img_name, grid_image, global_step)
def train_vis(self, model, dataset, writer: SummaryWriter, global_step,
              indices, device, cond_steps, fg_sample, bg_sample, num_gen):
    """Write tracking and generation visualizations to tensorboard:
    one summary grid image plus one video per clip, for each task."""
    # Tracking visualization
    track_grid, track_vids = self.show_tracking(model, dataset, indices,
                                                device)
    writer.add_image('tracking/grid', track_grid, global_step)
    for i in range(len(track_vids)):
        # slice keeps the leading batch dim expected by add_video
        writer.add_video(f'tracking/video_{i}', track_vids[i:i + 1],
                         global_step)

    # Generation visualization
    gen_grid, gen_vids = self.show_generation(model, dataset, indices,
                                              device, cond_steps,
                                              fg_sample=fg_sample,
                                              bg_sample=bg_sample,
                                              num=num_gen)
    writer.add_image('generation/grid', gen_grid, global_step)
    for i in range(len(gen_vids)):
        writer.add_video(f'generation/video_{i}', gen_vids[i:i + 1],
                         global_step)
def train(self, lens, texts, movies, pp=0, lp=20):
    """Run the full text-to-video GAN training loop.

    lp: log period
    pp: print period (0 - no print to stdout)

    `lens`/`texts` condition the generator via self.encoder; `movies`
    holds the real clips (logged once, then rebound to generated clips
    on logging epochs).
    """
    writer = SummaryWriter()
    # Log the ground-truth clips once, for visual comparison with fakes.
    writer.add_video("Real Clips", to_video(movies))
    device = next(self.generator.parameters()).device

    # Negative start time; adding time.time() later yields the elapsed span.
    time_per_epoch = -time.time()
    for epoch in range(self.num_epochs):
        for No, batch in enumerate(self.vloader):
            labels = batch['label'].to(device, non_blocking=True)
            videos = batch['video'].to(device, non_blocking=True)
            senlen = batch['slens'].to(device, non_blocking=True)
            # Accumulates losses into self.logs — see passBatchThroughNetwork.
            self.passBatchThroughNetwork(labels, videos, senlen)

        if pp and epoch % pp == 0:
            time_per_epoch += time.time()
            print(f'Epoch {epoch}/{self.num_epochs}')
            for k, v in self.logs.items():
                # Average accumulated losses over the number of batches.
                print("\t%s:\t%5.4f" % (k, v / (No + 1)))
                self.logs[k] = v / (No + 1)
            print('Completed in %.f s' % time_per_epoch)
            time_per_epoch = -time.time()

        if epoch % lp == 0:
            # Periodic evaluation: generate clips from the text condition
            # and log losses + fakes, then checkpoint the generator.
            self.generator.eval()
            with torch.no_grad():
                condition = self.encoder(texts, lens)
                movies = self.generator(condition)
            writer.add_scalars('Loss', self.logs, epoch)
            writer.add_video('Fakes', to_video(movies), epoch)
            self.generator.train()
            # Reset accumulated losses for the next logging window.
            self.logs = dict.fromkeys(self.logs, 0)
            torch.save(self.generator.state_dict(),
                       self.log_folder / ('gen_%05d.pytorch' % epoch))

    print('Training has been completed successfully!')
def main():
    """Generate videos of increasing length (16..48 frames) with a
    pretrained generator; log each batch to tensorboard and save it."""
    args = cfg.parse_args()

    # write into tensorboard
    log_path = os.path.join(args.demo_path, args.demo_name + '/log')
    vid_path = os.path.join(args.demo_path, args.demo_name + '/vids')
    # Fix: the original guard `if not exists(log) and not exists(vid)`
    # skipped creating BOTH directories whenever either one already
    # existed, so a stale log dir left the vids dir missing.
    os.makedirs(log_path, exist_ok=True)
    os.makedirs(vid_path, exist_ok=True)

    writer = SummaryWriter(log_path)

    device = torch.device("cuda:0")
    G = Generator().to(device)
    G = nn.DataParallel(G)
    G.load_state_dict(torch.load(args.model_path))

    with torch.no_grad():
        G.eval()
        za = torch.randn(args.n_za_test, args.d_za, 1, 1, 1).to(device)  # appearance
        # generating frames from [16, 20, 24, 28, 32, 36, 40, 44, 48]
        for i in range(9):
            # Motion code depth (i+1) controls the clip length: 16+i*4 frames.
            zm = torch.randn(args.n_zm_test, args.d_zm, (i + 1), 1,
                             1).to(device)  # 16+i*4
            vid_fake = G(za, zm)
            vid_fake = vid_fake.transpose(2, 1)
            # Min-max rescale to [0, 1] for tensorboard display.
            vid_fake = ((vid_fake - vid_fake.min()) /
                        (vid_fake.max() - vid_fake.min())).data

            writer.add_video(tag='generated_videos_%dframes' % (16 + i * 4),
                             global_step=1,
                             vid_tensor=vid_fake)
            writer.flush()

            print('saving videos')
            save_videos(vid_path, vid_fake, args.n_za_test, (16 + i * 4))

    return
class Trainer():
    """Training/validation driver for a 3D human-pose video generator.

    Wires together 2D/3D feature dataloaders, the generator, its optimizer,
    a loss criterion and a tensorboard writer; `fit()` runs the epoch loop,
    evaluation and checkpointing. Targets are dicts of tensors (keys such
    as 'features', 'kp_3d', 'theta', 'video' — schema defined by the
    project's dataset classes; confirm there).
    """

    def __init__(
        self,
        data_loaders,
        generator,
        gen_optimizer,
        end_epoch,
        criterion,
        start_epoch=0,
        lr_scheduler=None,
        device=None,
        writer=None,
        debug=False,
        debug_freq=1000,
        logdir='output',
        resume=None,
        performance_type='min',
        num_iters_per_epoch=1000,
    ):
        # Prepare dataloaders
        self.train_2d_loader, self.train_3d_loader, self.valid_loader = data_loaders

        self.train_2d_iter = self.train_3d_iter = None

        if self.train_2d_loader:
            self.train_2d_iter = iter(self.train_2d_loader)

        if self.train_3d_loader:
            self.train_3d_iter = iter(self.train_3d_loader)

        # Models and optimizers
        self.generator = generator
        self.gen_optimizer = gen_optimizer

        # Training parameters
        self.start_epoch = start_epoch
        self.end_epoch = end_epoch
        self.criterion = criterion
        self.lr_scheduler = lr_scheduler
        self.device = device
        self.writer = writer
        self.debug = debug
        self.debug_freq = debug_freq
        self.logdir = logdir
        self.performance_type = performance_type
        self.train_global_step = 0
        self.valid_global_step = 0
        self.epoch = 0
        # Best metric so far; 'min' means lower is better (e.g. MPJPE).
        self.best_performance = float(
            'inf') if performance_type == 'min' else -float('inf')

        # Per-epoch buffers filled by validate() and stacked by evaluate().
        self.evaluation_accumulators = dict.fromkeys(
            ['pred_j3d', 'target_j3d', 'target_theta', 'pred_verts'])

        self.num_iters_per_epoch = num_iters_per_epoch

        if self.writer is None:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(log_dir=self.logdir)

        if self.device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Resume from a pretrained model
        if resume is not None:
            self.resume_pretrained(resume)

    def train(self):
        """Run one training epoch of `num_iters_per_epoch` iterations."""
        # Single epoch training routine

        losses = AverageMeter()

        # Wall-clock breakdown of each iteration, logged for profiling.
        timer = {
            'data': 0,
            'forward': 0,
            'loss': 0,
            'backward': 0,
            'batch': 0,
        }

        self.generator.train()

        start = time.time()

        summary_string = ''

        # bar = Bar(f'Epoch {self.epoch + 1}/{self.end_epoch}', fill='#', max=self.num_iters_per_epoch)
        pbar = tqdm(range(self.num_iters_per_epoch))
        for i in pbar:
            # Dirty solution to reset an iterator
            target_2d = target_3d = None
            if self.train_2d_iter:
                try:
                    target_2d = next(self.train_2d_iter)
                except StopIteration:
                    # Exhausted: restart the 2D loader from the beginning.
                    self.train_2d_iter = iter(self.train_2d_loader)
                    target_2d = next(self.train_2d_iter)

                move_dict_to_device(target_2d, self.device)

            if self.train_3d_iter:
                try:
                    target_3d = next(self.train_3d_iter)
                except StopIteration:
                    self.train_3d_iter = iter(self.train_3d_loader)
                    target_3d = next(self.train_3d_iter)

                move_dict_to_device(target_3d, self.device)

            # <======= Feedforward generator and discriminator
            # Concatenate 2D and 3D feature batches when both are present.
            if target_2d and target_3d:
                inp = torch.cat((target_2d['features'], target_3d['features']),
                                dim=0).to(self.device)
            elif target_3d:
                inp = target_3d['features'].to(self.device)
            else:
                inp = target_2d['features'].to(self.device)

            timer['data'] = time.time() - start
            start = time.time()

            preds = self.generator(inp)

            timer['forward'] = time.time() - start
            start = time.time()

            gen_loss, loss_dict = self.criterion(
                generator_outputs=preds,
                data_2d=target_2d,
                data_3d=target_3d,
            )
            # =======>

            timer['loss'] = time.time() - start
            start = time.time()

            # <======= Backprop generator and discriminator
            self.gen_optimizer.zero_grad()
            gen_loss.backward()
            self.gen_optimizer.step()

            # if self.train_global_step % self.dis_motion_update_steps == 0:
            #     self.dis_motion_optimizer.zero_grad()
            #     motion_dis_loss.backward()
            #     self.dis_motion_optimizer.step()
            # =======>

            # <======= Log training info
            # total_loss = gen_loss + motion_dis_loss
            total_loss = gen_loss
            losses.update(total_loss.item(), inp.size(0))

            timer['backward'] = time.time() - start
            timer['batch'] = timer['data'] + timer['forward'] + timer[
                'loss'] + timer['backward']
            start = time.time()

            # summary_string = f'({i + 1}/{self.num_iters_per_epoch}) | Total: {bar.elapsed_td} | ' \
            #                  f'ETA: {bar.eta_td:} | loss: {losses.avg:.4f}'
            summary_string = '| loss: {:.4f}'.format(losses.avg)

            for k, v in loss_dict.items():
                summary_string += f' | {k}: {v:.2f}'
                self.writer.add_scalar('train_loss/' + k,
                                       v,
                                       global_step=self.train_global_step)

            # for k,v in timer.items():
            #     summary_string += ' | {}: {:.2f}'.format(k, v)

            self.writer.add_scalar('train_loss/loss',
                                   total_loss.item(),
                                   global_step=self.train_global_step)

            if self.debug:
                # Render predictions over the input video for inspection.
                print('==== Visualize ====')
                from psypose.MEVA.meva.utils.vis import batch_visualize_vid_preds
                video = target_3d['video']
                dataset = 'spin'
                vid_tensor = batch_visualize_vid_preds(video,
                                                       preds[-1],
                                                       target_3d.copy(),
                                                       vis_hmr=False,
                                                       dataset=dataset)
                self.writer.add_video('train-video',
                                      vid_tensor,
                                      global_step=self.train_global_step,
                                      fps=10)

            self.train_global_step += 1
            # bar.suffix = summary_string
            # bar.next()
            pbar.set_description(summary_string)

            if torch.isnan(total_loss):
                exit('Nan value in loss, exiting!...')
            # =======>

        # bar.finish()

        logger.info(summary_string)

    def validate(self):
        """Run inference over the validation set, filling the evaluation
        accumulators with predicted/target joints, verts and thetas."""
        self.generator.eval()

        start = time.time()

        summary_string = ''

        bar = Bar('Validation', fill='#', max=len(self.valid_loader))

        # Reset accumulators to fresh lists for this validation pass.
        if self.evaluation_accumulators is not None:
            for k, v in self.evaluation_accumulators.items():
                self.evaluation_accumulators[k] = []

        # H3.6M joint regressor used to produce evaluation keypoints.
        J_regressor = torch.from_numpy(
            np.load(osp.join(MEVA_DATA_DIR, 'J_regressor_h36m.npy'))).float()

        for i, target in enumerate(self.valid_loader):
            move_dict_to_device(target, self.device)

            # <=============
            with torch.no_grad():
                inp = target['features']

                preds = self.generator(inp, J_regressor=J_regressor)

                # convert to 14 keypoint format for evaluation
                n_kp = preds[-1]['kp_3d'].shape[-2]
                pred_j3d = preds[-1]['kp_3d'].view(-1, n_kp, 3).cpu().numpy()
                target_j3d = target['kp_3d'].view(-1, n_kp, 3).cpu().numpy()
                pred_verts = preds[-1]['verts'].view(-1, 6890, 3).cpu().numpy()
                target_theta = target['theta'].view(-1, 85).cpu().numpy()

                self.evaluation_accumulators['pred_verts'].append(pred_verts)
                self.evaluation_accumulators['target_theta'].append(
                    target_theta)

                self.evaluation_accumulators['pred_j3d'].append(pred_j3d)
                self.evaluation_accumulators['target_j3d'].append(target_j3d)
            # =============>

            # <============= DEBUG
            if self.debug and self.valid_global_step % self.debug_freq == 0:
                from psypose.MEVA.meva.utils.vis import batch_visualize_vid_preds
                video = target['video']
                dataset = 'common'
                vid_tensor = batch_visualize_vid_preds(video,
                                                       preds[-1],
                                                       target,
                                                       vis_hmr=False,
                                                       dataset=dataset)
                self.writer.add_video('valid-video',
                                      vid_tensor,
                                      global_step=self.valid_global_step,
                                      fps=10)
            # =============>

            batch_time = time.time() - start

            summary_string = f'({i + 1}/{len(self.valid_loader)}) | batch: {batch_time * 10.0:.4}ms | ' \
                             f'Total: {bar.elapsed_td} | ETA: {bar.eta_td:}'

            self.valid_global_step += 1

            bar.suffix = summary_string
            bar.next()

        bar.finish()

        logger.info(summary_string)

    def fit(self):
        """Full run: train, validate and evaluate each epoch; step the LR
        scheduler on the evaluation metric and checkpoint the model."""
        for epoch in range(self.start_epoch, self.end_epoch):
            self.epoch = epoch
            self.train()
            self.validate()
            performance = self.evaluate()

            if self.lr_scheduler is not None:
                # Scheduler steps on the metric (ReduceLROnPlateau-style).
                self.lr_scheduler.step(performance)

            # log the learning rate
            for param_group in self.gen_optimizer.param_groups:
                print(f'Learning rate {param_group["lr"]}')
                self.writer.add_scalar('lr/gen_lr',
                                       param_group['lr'],
                                       global_step=self.epoch)

            # for param_group in self.dis_motion_optimizer.param_groups:
            #     print(f'Learning rate {param_group["lr"]}')
            #     self.writer.add_scalar('lr/dis_lr', param_group['lr'], global_step=self.epoch)

            logger.info(f'Epoch {epoch+1} performance: {performance:.4f}')

            self.save_model(performance, epoch)

            # if performance > 200.0:
            #     exit(f'MPJPE error is {performance}, higher than 80.0. Exiting!...')

        self.writer.close()

    def save_model(self, performance, epoch):
        """Save a checkpoint; also copy it to model_best.pth.tar (and
        record the metric in best.txt) when performance improves."""
        save_dict = {
            'epoch': epoch,
            'gen_state_dict': self.generator.state_dict(),
            'performance': performance,
            'gen_optimizer': self.gen_optimizer.state_dict(),
            # 'disc_motion_state_dict': self.motion_discriminator.state_dict(),
            # 'disc_motion_optimizer': self.dis_motion_optimizer.state_dict(),
        }

        filename = osp.join(self.logdir, 'checkpoint.pth.tar')
        torch.save(save_dict, filename)

        if self.performance_type == 'min':
            is_best = performance < self.best_performance
        else:
            is_best = performance > self.best_performance

        if is_best:
            logger.info('Best performance achived, saving it!')
            self.best_performance = performance
            shutil.copyfile(filename,
                            osp.join(self.logdir, 'model_best.pth.tar'))

            with open(osp.join(self.logdir, 'best.txt'), 'w') as f:
                f.write(str(float(performance)))

    def resume_pretrained(self, model_path):
        """Restore generator/optimizer/epoch state from a checkpoint file,
        if it exists; otherwise just log a message."""
        if osp.isfile(model_path):
            checkpoint = torch.load(model_path)
            self.start_epoch = checkpoint['epoch']
            self.generator.load_state_dict(checkpoint['gen_state_dict'])
            self.gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
            self.best_performance = checkpoint['performance']

            # if 'disc_motion_optimizer' in checkpoint.keys():
            #     self.motion_discriminator.load_state_dict(checkpoint['disc_motion_state_dict'])
            #     self.dis_motion_optimizer.load_state_dict(checkpoint['disc_motion_optimizer'])

            logger.info(
                f"=> loaded checkpoint '{model_path}' "
                f"(epoch {self.start_epoch}, performance {self.best_performance})"
            )
        else:
            logger.info(f"=> no checkpoint found at '{model_path}'")

    def evaluate(self):
        """Compute MPJPE, PA-MPJPE, PVE and acceleration metrics (in mm)
        from the accumulated validation outputs; returns PA-MPJPE."""
        # Stack the per-batch lists collected by validate() into arrays.
        for k, v in self.evaluation_accumulators.items():
            self.evaluation_accumulators[k] = np.vstack(v)

        pred_j3ds = self.evaluation_accumulators['pred_j3d']
        target_j3ds = self.evaluation_accumulators['target_j3d']

        pred_j3ds = torch.from_numpy(pred_j3ds).float()
        target_j3ds = torch.from_numpy(target_j3ds).float()

        print(f'Evaluating on {pred_j3ds.shape[0]} number of poses...')
        # Center both skeletons on the pelvis (midpoint of joints 2 and 3 —
        # presumably the hips in the 14-kp format; confirm joint order).
        pred_pelvis = (pred_j3ds[:, [2], :] + pred_j3ds[:, [3], :]) / 2.0
        target_pelvis = (target_j3ds[:, [2], :] + target_j3ds[:, [3], :]) / 2.0

        pred_j3ds -= pred_pelvis
        target_j3ds -= target_pelvis

        # Absolute error (MPJPE)
        errors = torch.sqrt(((pred_j3ds - target_j3ds)**2).sum(
            dim=-1)).mean(dim=-1).cpu().numpy()

        # Procrustes-aligned error (PA-MPJPE).
        S1_hat = batch_compute_similarity_transform_torch(
            pred_j3ds, target_j3ds)
        errors_pa = torch.sqrt(((S1_hat - target_j3ds)**2).sum(
            dim=-1)).mean(dim=-1).cpu().numpy()

        pred_verts = self.evaluation_accumulators['pred_verts']
        target_theta = self.evaluation_accumulators['target_theta']

        m2mm = 1000  # meters -> millimeters

        pve = np.mean(
            compute_error_verts(target_theta=target_theta,
                                pred_verts=pred_verts)) * m2mm
        accel = np.mean(compute_accel(pred_j3ds)) * m2mm
        accel_err = np.mean(
            compute_error_accel(joints_pred=pred_j3ds,
                                joints_gt=target_j3ds)) * m2mm
        mpjpe = np.mean(errors) * m2mm
        pa_mpjpe = np.mean(errors_pa) * m2mm

        eval_dict = {
            'mpjpe': mpjpe,
            'pa-mpjpe': pa_mpjpe,
            'accel': accel,
            'pve': pve,
            'accel_err': accel_err
        }

        log_str = f'Epoch {self.epoch}, '
        log_str += ' '.join(
            [f'{k.upper()}: {v:.4f},' for k, v in eval_dict.items()])
        logger.info(log_str)

        for k, v in eval_dict.items():
            self.writer.add_scalar(f'error/{k}', v, global_step=self.epoch)

        return pa_mpjpe
class Logger(object):
    """Tensorboard logger that can also accumulate per-tag image frames
    and emit them as one video when the sequence is marked finished."""

    def __init__(self, log_dir, comment=''):
        self.writer = SummaryWriter(log_dir=log_dir, comment=comment)
        # tag -> list of frames collected so far (see log_video)
        self.imgs_dict = {}

    def scalar_summary(self, tag, value, step):
        """Log one scalar and flush immediately."""
        self.writer.add_scalar(tag, value, global_step=step)
        self.writer.flush()

    def combined_scalars_summary(self, main_tag, tag_scalar_dict, step):
        """Log several scalars under one main tag (shared plot)."""
        self.writer.add_scalars(main_tag, tag_scalar_dict, step)
        self.writer.flush()

    def log(self, tag, text_string, step=0):
        """Log free-form text."""
        self.writer.add_text(tag, text_string, step)
        self.writer.flush()

    def log_model(self, model, inputs):
        """Log the model graph, traced with `inputs`."""
        self.writer.add_graph(model, inputs)
        self.writer.flush()

    def get_dir(self):
        # Directory where event files (and saved model states) are written.
        return self.writer.get_logdir()

    def log_model_state(self, model, name='tmp'):
        """Save the model's state_dict next to the event files."""
        path = os.path.join(self.writer.get_logdir(),
                            type(model).__name__ + '_%s.pt' % name)
        torch.save(model.state_dict(), path)

    def log_video(self,
                  tag,
                  global_step=None,
                  img_tns=None,
                  finished_video=False,
                  video_tns=None,
                  debug=False):
        '''
        Logs video to tensorboard. video_tns will usually be empty.
        If given image tensors, then when finished_video = True, the video
        of the past tensors will be made into one video.
        If video_tns is not empty, then that will be logged as the video
        and the other arguments will be ignored.
        '''
        if debug:
            import pdb
            pdb.set_trace()
        if img_tns is None and video_tns is None:
            # Flush-only call: emit the frames accumulated under `tag`.
            if not finished_video or tag not in self.imgs_dict.keys():
                return None
            lst_img_tns = self.imgs_dict[tag]
            # NOTE(review): torch.tensor over the frame list assumes the
            # frames are stackable array-likes (e.g. numpy) — confirm
            # against callers; a list of torch Tensors would not convert.
            self.writer.add_video(tag,
                                  torch.tensor(lst_img_tns),
                                  global_step=global_step,
                                  fps=4)
            self.writer.flush()
            self.imgs_dict[tag] = []
            return None
        elif video_tns is not None:
            # A ready-made video tensor takes precedence over accumulation.
            self.writer.add_video(tag,
                                  video_tns,
                                  global_step=global_step,
                                  fps=4)
            self.writer.flush()
            return None

        # Accumulate this frame under `tag`, creating the list on first use.
        if tag in self.imgs_dict.keys():
            lst_img_tns = self.imgs_dict[tag]
        else:
            lst_img_tns = []
            self.imgs_dict[tag] = lst_img_tns
        lst_img_tns.append(img_tns)

        if finished_video:
            # Emit the completed sequence and clear the buffer for reuse.
            self.writer.add_video(tag,
                                  torch.tensor(lst_img_tns),
                                  global_step=global_step,
                                  fps=4)
            self.writer.flush()
            self.imgs_dict[tag].clear()

    def close(self):
        self.writer.close()
class Logger(object):
    """Training/eval metrics logger.

    Scalars are mirrored into two ``MetersGroup`` aggregators (train / eval)
    and, when ``save_tb`` is set, into a tensorboard ``SummaryWriter``.
    """

    def __init__(self, log_dir, save_tb=False, log_frequency=10000, agent="sac"):
        self._log_dir = log_dir
        self._log_frequency = log_frequency
        if save_tb:
            tb_dir = os.path.join(log_dir, "tb")
            if os.path.exists(tb_dir):
                try:
                    shutil.rmtree(tb_dir)
                except OSError:
                    # Best effort: a stale tb dir only means mixed event files.
                    print("logger.py warning: Unable to remove tb directory")
            self._sw = SummaryWriter(tb_dir)
        else:
            self._sw = None
        # each agent has specific output format for training
        assert agent in AGENT_TRAIN_FORMAT
        train_format = COMMON_TRAIN_FORMAT + AGENT_TRAIN_FORMAT[agent]
        self._train_mg = MetersGroup(os.path.join(log_dir, "train"),
                                     formating=train_format)
        self._eval_mg = MetersGroup(os.path.join(log_dir, "eval"),
                                    formating=COMMON_EVAL_FORMAT)

    def _should_log(self, step, log_frequency):
        # A per-call frequency (if given) overrides the instance default.
        log_frequency = log_frequency or self._log_frequency
        return step % log_frequency == 0

    def _try_sw_log(self, key, value, step):
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def _try_sw_log_video(self, key, frames, step):
        if self._sw is not None:
            frames = torch.from_numpy(np.array(frames))
            # add a leading batch dimension for add_video's (N,T,C,H,W) layout
            frames = frames.unsqueeze(0)
            self._sw.add_video(key, frames, step, fps=30)

    def _try_sw_log_histogram(self, key, histogram, step):
        if self._sw is not None:
            self._sw.add_histogram(key, histogram, step)

    def log(self, key, value, step, n=1, log_frequency=1):
        """Log a scalar (averaged over ``n`` samples) under ``key``."""
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith("train") or key.startswith("eval")
        if type(value) == torch.Tensor:
            value = value.item()
        self._try_sw_log(key, value / n, step)
        mg = self._train_mg if key.startswith("train") else self._eval_mg
        mg.log(key, value, n)

    def log_param(self, key, param, step, log_frequency=None):
        """Log weight/bias histograms (and their grads, when present)."""
        if not self._should_log(step, log_frequency):
            return
        self.log_histogram(key + "_w", param.weight.data, step)
        if hasattr(param.weight, "grad") and param.weight.grad is not None:
            self.log_histogram(key + "_w_g", param.weight.grad.data, step)
        if hasattr(param, "bias") and hasattr(param.bias, "data"):
            self.log_histogram(key + "_b", param.bias.data, step)
            if hasattr(param.bias, "grad") and param.bias.grad is not None:
                self.log_histogram(key + "_b_g", param.bias.grad.data, step)

    def log_video(self, key, frames, step, log_frequency=None):
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith("train") or key.startswith("eval")
        self._try_sw_log_video(key, frames, step)

    def log_histogram(self, key, histogram, step, log_frequency=None):
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith("train") or key.startswith("eval")
        self._try_sw_log_histogram(key, histogram, step)

    def dump(self, step, save=True, ty=None):
        """Dump aggregated meters; ``ty`` is None (both), 'train', or 'eval'."""
        if ty is None:
            self._train_mg.dump(step, "train", save)
            self._eval_mg.dump(step, "eval", save)
        elif ty == "eval":
            self._eval_mg.dump(step, "eval", save)
        elif ty == "train":
            self._train_mg.dump(step, "train", save)
        else:
            # BUG FIX: ``raise f"..."`` raised a str, itself a TypeError in
            # Python 3 ("exceptions must derive from BaseException").
            raise ValueError(f"invalid log type: {ty}")
class A2CTrial(PyTorchTrial):
    """Determined-AI ``PyTorchTrial`` implementing synchronous A2C on Atari.

    NOTE(review): the MNIST data loaders below are never consumed by
    ``train_batch`` (which runs its own environment rollout); they appear to
    exist only to drive the trial loop's batch scheduling -- confirm intended.
    """

    def __init__(self, trial_context: PyTorchTrialContext) -> None:
        self.context = trial_context
        # Per-rank scratch directory for the (placeholder) dataset download.
        self.download_directory = f"/tmp/data-rank{self.context.distributed.get_rank()}"
        # self.logger = TorchWriter()
        # Hyperparameters supplied by the Determined experiment config.
        self.n_stack = self.context.get_hparam("n_stack")
        self.env_name = self.context.get_hparam("env_name")
        self.num_envs = self.context.get_hparam("num_envs")
        self.rollout_size = self.context.get_hparam("rollout_size")
        self.curiousity = self.context.get_hparam("curiousity")  # [sic] hparam key
        self.lr = self.context.get_hparam("lr")
        self.icm_beta = self.context.get_hparam("icm_beta")
        self.value_coeff = self.context.get_hparam("value_coeff")
        self.entropy_coeff = self.context.get_hparam("entropy_coeff")
        self.max_grad_norm = self.context.get_hparam("max_grad_norm")
        # Vectorized training env (num_envs copies) and a single eval env,
        # both with frame stacking.
        env = make_atari_env(self.env_name, num_env=self.num_envs, seed=42)
        self.env = VecFrameStack(env, n_stack=self.n_stack)
        eval_env = make_atari_env(self.env_name, num_env=1, seed=42)
        self.eval_env = VecFrameStack(eval_env, n_stack=self.n_stack)
        # constants
        self.in_size = self.context.get_hparam("in_size")  # in_size
        self.num_actions = env.action_space.n

        # Orthogonal weight init with zero bias for the policy/value heads.
        def init_(m):
            return init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0))

        # Shared feature encoder feeding separate actor and critic heads.
        self.feat_enc_net = self.context.Model(
            FeatureEncoderNet(self.n_stack, self.in_size))
        self.actor = self.context.Model(
            init_(nn.Linear(self.feat_enc_net.hidden_size, self.num_actions)))
        self.critic = self.context.Model(
            init_(nn.Linear(self.feat_enc_net.hidden_size, 1)))
        self.set_recurrent_buffers(self.num_envs)
        # One optimizer over all three modules.
        params = list(self.feat_enc_net.parameters()) + list(
            self.actor.parameters()) + list(self.critic.parameters())
        self.opt = self.context.Optimizer(torch.optim.Adam(params, self.lr))
        self.is_cuda = torch.cuda.is_available()
        # Rollout buffer; drops the channel axis (shape[0:-1]) because frames
        # are re-stacked via n_stack -- presumably HWC input, confirm.
        self.storage = RolloutStorage(self.rollout_size,
                                      self.num_envs,
                                      self.env.observation_space.shape[0:-1],
                                      self.n_stack,
                                      is_cuda=self.is_cuda,
                                      value_coeff=self.value_coeff,
                                      entropy_coeff=self.entropy_coeff)
        # Seed the first rollout state with the initial observation.
        obs = self.env.reset()
        self.storage.states[0].copy_(self.storage.obs2tensor(obs))
        self.writer = SummaryWriter(log_dir="/tmp/tensorboard")
        self.global_eval_count = 0

    def set_recurrent_buffers(self, buf_size):
        """(Re)allocate the encoder's LSTM state for ``buf_size`` envs."""
        self.feat_enc_net.reset_lstm(buf_size=buf_size)

    def reset_recurrent_buffers(self, reset_indices):
        """Zero the LSTM state of the envs flagged done in ``reset_indices``."""
        self.feat_enc_net.reset_lstm(reset_indices=reset_indices)

    def build_training_data_loader(self) -> DataLoader:
        # Placeholder loader (see class note): MNIST batches pace the loop.
        ds = torchvision.datasets.MNIST(
            self.download_directory,
            train=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # These are the precomputed mean and standard deviation of the
                # MNIST data; this normalizes the data to have zero mean and unit
                # standard deviation.
                transforms.Normalize((0.1307, ), (0.3081, )),
            ]),
            download=True)
        return DataLoader(ds, batch_size=1)

    def build_validation_data_loader(self) -> DataLoader:
        # Placeholder loader (see class note).
        ds = torchvision.datasets.MNIST(
            self.download_directory,
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # These are the precomputed mean and standard deviation of the
                # MNIST data; this normalizes the data to have zero mean and unit
                # standard deviation.
                transforms.Normalize((0.1307, ), (0.3081, )),
            ]),
            download=True)
        return DataLoader(ds, batch_size=1)

    def train_batch(self, batch: TorchData, model: nn.Module, epoch_idx: int,
                    batch_idx: int) -> Dict[str, torch.Tensor]:
        """Run one rollout + A2C update; ``batch``/``model`` are unused."""
        final_value, entropy = self.episode_rollout()
        self.opt.zero_grad()
        total_loss, value_loss, policy_loss, entropy_loss = self.storage.a2c_loss(
            final_value, entropy)
        self.context.backward(total_loss)

        # Gradient clipping is applied inside the optimizer step.
        def clip_grads(parameters):
            torch.nn.utils.clip_grad_norm_(parameters, self.max_grad_norm)

        self.context.step_optimizer(self.opt, clip_grads)
        self.storage.after_update()
        return {
            'loss': total_loss,
            'value_loss': value_loss,
            'policy_loss': policy_loss,
            'entropy_loss': entropy_loss
        }

    def get_action(self, state, deterministic=False):
        """Sample (or argmax) an action for ``state``.

        Returns (action, log_prob, entropy, value, features); in the
        deterministic branch log_prob and entropy are empty lists and the
        action is a numpy array rather than a tensor.
        """
        feature = self.feat_enc_net(state)
        # calculate policy and value function
        policy = self.actor(feature)
        value = torch.squeeze(self.critic(feature))
        action_prob = F.softmax(policy, dim=-1)
        cat = Categorical(action_prob)
        if not deterministic:
            action = cat.sample()
            return (action, cat.log_prob(action), cat.entropy().mean(), value,
                    feature)
        else:
            action = np.argmax(action_prob.detach().cpu().numpy(), axis=1)
            return (action, [], [], value, feature)

    def episode_rollout(self):
        """Collect ``rollout_size`` steps from all envs into the storage.

        Returns the critic's bootstrap value for the final state and the
        accumulated policy entropy over the rollout.
        """
        episode_entropy = 0
        for step in range(self.rollout_size):
            """Interact with the environments """
            # call A2C
            a_t, log_p_a_t, entropy, value, a2c_features = self.get_action(
                self.storage.get_state(step))
            # accumulate episode entropy
            episode_entropy += entropy
            # interact
            obs, rewards, dones, infos = self.env.step(a_t.cpu().numpy())
            # save episode reward
            self.storage.log_episode_rewards(infos)
            self.storage.insert(step, rewards, obs, a_t, log_p_a_t, value,
                                dones)
            self.reset_recurrent_buffers(reset_indices=dones)
        # Note:
        # get the estimate of the final reward
        # that's why we have the CRITIC --> estimate final reward
        # detach, as the final value will only be used as a constant
        # bootstrap target (no gradient flows through it).
        with torch.no_grad():
            state = self.storage.get_state(step + 1)
            final_features = self.feat_enc_net(state)
            final_value = torch.squeeze(self.critic(final_features))
        return final_value, episode_entropy

    def evaluate_full_dataset(self, data_loader, model) -> Dict[str, Any]:
        """Run 10 greedy eval episodes; video of the first one is logged."""
        self.global_eval_count += 1
        episode_rewards, episode_lengths = [], []
        n_eval_episodes = 10
        # Shrink the recurrent state to the single eval env.
        self.set_recurrent_buffers(1)
        frames = []
        with torch.no_grad():
            for episode in range(n_eval_episodes):
                obs = self.eval_env.reset()
                done, state = False, None
                episode_reward = 0.0
                episode_length = 0
                while not done:
                    state = self.storage.obs2tensor(obs)
                    if episode == 0:
                        # Keep only the first stacked channel as a video frame.
                        frame = torch.unsqueeze(torch.squeeze(state)[0],
                                                0).detach()
                        frames.append(frame)
                    action, _, _, _, _ = self.get_action(state,
                                                         deterministic=True)
                    obs, reward, done, _info = self.eval_env.step(action)
                    # Vectorized env returns per-env arrays; unwrap env 0.
                    reward = reward[0]
                    done = done[0]
                    episode_reward += reward
                    episode_length += 1
                if episode == 0:
                    video = torch.unsqueeze(torch.stack(frames), 0)
                    self.writer.add_video('policy',
                                          video,
                                          global_step=self.global_eval_count,
                                          fps=20)
                episode_rewards.append(episode_reward)
                episode_lengths.append(episode_length)
        mean_reward = np.mean(episode_rewards)
        std_reward = np.std(episode_rewards)  # NOTE(review): computed but unused
        # Restore the recurrent state size for training.
        self.set_recurrent_buffers(self.num_envs)
        return {'mean_reward': mean_reward}
class TensorboardWriter(object):
    """
    Helper class to log information to Tensorboard.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): configs. Details can be found in
                slowfast/config/defaults.py
        """
        # class_names: list of class names.
        # cm_subset_classes: a list of class ids -- a user-specified subset.
        # parent_map: dictionary where key is the parent class name and
        # value is a list of ids of its children classes.
        # hist_subset_classes: a list of class ids -- user-specified to plot histograms.
        (
            self.class_names,
            self.cm_subset_classes,
            self.parent_map,
            self.hist_subset_classes,
        ) = (None, None, None, None)
        self.cfg = cfg
        self.cm_figsize = cfg.TENSORBOARD.CONFUSION_MATRIX.FIGSIZE
        self.hist_figsize = cfg.TENSORBOARD.HISTOGRAM.FIGSIZE
        # Default log dir is derived from the training dataset name.
        if cfg.TENSORBOARD.LOG_DIR == "":
            log_dir = os.path.join(cfg.OUTPUT_DIR,
                                   "runs-{}".format(cfg.TRAIN.DATASET))
        else:
            log_dir = os.path.join(cfg.OUTPUT_DIR, cfg.TENSORBOARD.LOG_DIR)
        self.writer = SummaryWriter(log_dir=log_dir)
        logger.info(
            "To see logged results in Tensorboard, please launch using the command \
`tensorboard --port=<port-number> --logdir {}`".format(log_dir))
        if cfg.TENSORBOARD.CLASS_NAMES_PATH != "":
            if cfg.DETECTION.ENABLE:
                logger.info("Plotting confusion matrix is currently \
not supported for detection.")
            # NOTE(review): assignment order here is
            # (class_names, parent_map, cm_subset_classes) -- trusts the
            # return order of get_class_names; confirm against its docs.
            (
                self.class_names,
                self.parent_map,
                self.cm_subset_classes,
            ) = get_class_names(
                cfg.TENSORBOARD.CLASS_NAMES_PATH,
                cfg.TENSORBOARD.CATEGORIES_PATH,
                cfg.TENSORBOARD.CONFUSION_MATRIX.SUBSET_PATH,
            )
            if cfg.TENSORBOARD.HISTOGRAM.ENABLE:
                if cfg.DETECTION.ENABLE:
                    logger.info("Plotting histogram is not currently \
supported for detection tasks.")
                if cfg.TENSORBOARD.HISTOGRAM.SUBSET_PATH != "":
                    _, _, self.hist_subset_classes = get_class_names(
                        cfg.TENSORBOARD.CLASS_NAMES_PATH,
                        None,
                        cfg.TENSORBOARD.HISTOGRAM.SUBSET_PATH,
                    )

    def add_scalars(self, data_dict, global_step=None):
        """
        Add multiple scalars to Tensorboard logs.
        Args:
            data_dict (dict): key is a string specifying the tag of value.
            global_step (Optinal[int]): Global step value to record.
        """
        if self.writer is not None:
            for key, item in data_dict.items():
                self.writer.add_scalar(key, item, global_step)

    def plot_eval(self, preds, labels, global_step=None):
        """
        Plot confusion matrices and histograms for eval/test set.
        Args:
            preds (tensor or list of tensors): list of predictions.
            labels (tensor or list of tensors): list of labels.
            global step (Optional[int]): current step in eval/test.
        """
        if not self.cfg.DETECTION.ENABLE:
            cmtx = None
            if self.cfg.TENSORBOARD.CONFUSION_MATRIX.ENABLE:
                cmtx = vis_utils.get_confusion_matrix(
                    preds, labels, self.cfg.MODEL.NUM_CLASSES)
                # Add full confusion matrix.
                add_confusion_matrix(
                    self.writer,
                    cmtx,
                    self.cfg.MODEL.NUM_CLASSES,
                    global_step=global_step,
                    class_names=self.class_names,
                    figsize=self.cm_figsize,
                )
                # If a list of subset is provided, plot confusion matrix subset.
                if self.cm_subset_classes is not None:
                    add_confusion_matrix(
                        self.writer,
                        cmtx,
                        self.cfg.MODEL.NUM_CLASSES,
                        global_step=global_step,
                        subset_ids=self.cm_subset_classes,
                        class_names=self.class_names,
                        tag="Confusion Matrix Subset",
                        figsize=self.cm_figsize,
                    )
                # If a parent-child classes mapping is provided, plot confusion
                # matrices grouped by parent classes.
                if self.parent_map is not None:
                    # Get list of tags (parent categories names) and their children.
                    for parent_class, children_ls in self.parent_map.items():
                        tag = (
                            "Confusion Matrices Grouped by Parent Classes/" +
                            parent_class)
                        add_confusion_matrix(
                            self.writer,
                            cmtx,
                            self.cfg.MODEL.NUM_CLASSES,
                            global_step=global_step,
                            subset_ids=children_ls,
                            class_names=self.class_names,
                            tag=tag,
                            figsize=self.cm_figsize,
                        )
            if self.cfg.TENSORBOARD.HISTOGRAM.ENABLE:
                # Reuse the confusion matrix when already computed above.
                if cmtx is None:
                    cmtx = vis_utils.get_confusion_matrix(
                        preds, labels, self.cfg.MODEL.NUM_CLASSES)
                plot_hist(
                    self.writer,
                    cmtx,
                    self.cfg.MODEL.NUM_CLASSES,
                    self.cfg.TENSORBOARD.HISTOGRAM.TOPK,
                    global_step=global_step,
                    subset_ids=self.hist_subset_classes,
                    class_names=self.class_names,
                    figsize=self.hist_figsize,
                )

    def add_video(self, vid_tensor, tag="Video Input", global_step=None, fps=4):
        """
        Add input to tensorboard SummaryWriter as a video.
        Args:
            vid_tensor (tensor): shape of (B, T, C, H, W). Values should lie
                [0, 255] for type uint8 or [0, 1] for type float.
            tag (Optional[str]): name of the video.
            global_step(Optional[int]): current step.
            fps (int): frames per second.
        """
        self.writer.add_video(tag,
                              vid_tensor,
                              global_step=global_step,
                              fps=fps)

    def plot_weights_and_activations(
        self,
        weight_activation_dict,
        tag="",
        normalize=False,
        global_step=None,
        batch_idx=None,
        indexing_dict=None,
        heat_map=True,
    ):
        """
        Visualize weights/ activations tensors to Tensorboard.
        Args:
            weight_activation_dict (dict[str, tensor]): a dictionary of the pair
                {layer_name: tensor}, where layer_name is a string and tensor is
                the weights/activations of the layer we want to visualize.
            tag (Optional[str]): name of the video.
            normalize (bool): If True, the tensor is normalized.
                (Default to False)
            global_step(Optional[int]): current step.
            batch_idx (Optional[int]): current batch index to visualize.
                If None, visualize the entire batch.
            indexing_dict (Optional[dict]): a dictionary of the
                {layer_name: indexing}, where indexing is numpy-like fancy
                indexing.
            heatmap (bool): whether to add heatmap to the weights/ activations.
        """
        for name, array in weight_activation_dict.items():
            if batch_idx is None:
                # Select all items in the batch if batch_idx is not provided.
                # NOTE(review): once defaulted, this same index list is reused
                # for every subsequent layer -- assumes equal batch sizes.
                batch_idx = list(range(array.shape[0]))
            if indexing_dict is not None:
                fancy_indexing = indexing_dict[name]
                fancy_indexing = (batch_idx, ) + fancy_indexing
                array = array[fancy_indexing]
            else:
                array = array[batch_idx]
            add_ndim_array(
                self.writer,
                array,
                tag + name,
                normalize=normalize,
                global_step=global_step,
                heat_map=heat_map,
            )

    def flush(self):
        self.writer.flush()

    def close(self):
        self.writer.flush()
        self.writer.close()
class Logger:
    """Combined console/file logger and tensorboard writer.

    NOTE(review): ``self._writer`` (and ``self.log_path``) only exist when
    both ``log_dir`` and ``phase`` are given; the tensorboard delegation
    methods raise AttributeError otherwise -- confirm callers always pass both.
    """
    # Counts instances so each gets a uniquely named stdlib logger.
    _count = 0

    def __init__(self, scrn=True, log_dir='', phase=''):
        super().__init__()
        self._logger = logging.getLogger('logger_{}'.format(Logger._count))
        Logger._count += 1
        self._logger.setLevel(logging.DEBUG)
        if scrn:
            # Console handler: INFO and above, short format.
            self._scrn_handler = logging.StreamHandler()
            self._scrn_handler.setLevel(logging.INFO)
            self._scrn_handler.setFormatter(
                logging.Formatter(fmt=FORMAT_SHORT))
            self._logger.addHandler(self._scrn_handler)
        if log_dir and phase:
            # File handler: DEBUG and above, long format; file name carries a
            # timestamp so each run logs separately.
            self.log_path = os.path.join(
                log_dir,
                '{}-{:-4d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}.log'.format(
                    phase, *localtime()[:6]))
            self.show_nl("log into {}\n\n".format(self.log_path))
            self._file_handler = logging.FileHandler(filename=self.log_path)
            self._file_handler.setLevel(logging.DEBUG)
            self._file_handler.setFormatter(logging.Formatter(fmt=FORMAT_LONG))
            self._logger.addHandler(self._file_handler)
            self._writer = SummaryWriter(log_dir=os.path.join(
                log_dir, '{}-{:-4d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}'.format(
                    phase, *localtime()[:6])))

    def show(self, *args, **kwargs):
        return self._logger.info(*args, **kwargs)

    def show_nl(self, *args, **kwargs):
        # Emit a blank line before the message.
        self._logger.info("")
        return self.show(*args, **kwargs)

    def dump(self, *args, **kwargs):
        return self._logger.debug(*args, **kwargs)

    def warning(self, *args, **kwargs):
        return self._logger.warning(*args, **kwargs)

    def error(self, *args, **kwargs):
        return self._logger.error(*args, **kwargs)

    # tensorboard -- thin delegation to the underlying SummaryWriter.
    def add_scalar(self, *args, **kwargs):
        return self._writer.add_scalar(*args, **kwargs)

    def add_scalars(self, *args, **kwargs):
        return self._writer.add_scalars(*args, **kwargs)

    def add_histogram(self, *args, **kwargs):
        return self._writer.add_histogram(*args, **kwargs)

    def add_image(self, *args, **kwargs):
        return self._writer.add_image(*args, **kwargs)

    def add_images(self, *args, **kwargs):
        return self._writer.add_images(*args, **kwargs)

    def add_figure(self, *args, **kwargs):
        return self._writer.add_figure(*args, **kwargs)

    def add_video(self, *args, **kwargs):
        return self._writer.add_video(*args, **kwargs)

    def add_audio(self, *args, **kwargs):
        return self._writer.add_audio(*args, **kwargs)

    def add_text(self, *args, **kwargs):
        return self._writer.add_text(*args, **kwargs)

    def add_graph(self, *args, **kwargs):
        return self._writer.add_graph(*args, **kwargs)

    def add_pr_curve(self, *args, **kwargs):
        return self._writer.add_pr_curve(*args, **kwargs)

    def add_custom_scalars(self, *args, **kwargs):
        return self._writer.add_custom_scalars(*args, **kwargs)

    def add_mesh(self, *args, **kwargs):
        return self._writer.add_mesh(*args, **kwargs)

    # def add_hparams(self, *args, **kwargs):
    #     return self._writer.add_hparams(*args, **kwargs)

    def flush(self):
        return self._writer.flush()

    def close(self):
        return self._writer.close()

    def _grad_hook(self, grad, name=None, grads=None):
        # Backward hook: stash the parameter's gradient under its name.
        grads.update({name: grad})

    def watch_grad(self, model, layers):
        """
        Add hooks to the specific layers. Gradients of these layers will save
        to self.grads
        :param model:
        :param layers: Expects a list, e.g. layers=[0, -1] means to watch the
            gradients of the first layer and the last layer of the model
        :return:
        """
        assert layers
        if not hasattr(self, 'grads'):
            self.grads = {}
            self.grad_hooks = {}
        named_parameters = list(model.named_parameters())
        for layer in layers:
            name = named_parameters[layer][0]
            handle = named_parameters[layer][1].register_hook(
                functools.partial(self._grad_hook, name=name,
                                  grads=self.grads))
            # BUG FIX: was ``self.grad_hooks.update(dict(name=handle))`` which
            # stored every handle under the literal key "name", so
            # watch_grad_close() could only ever remove the last hook.
            self.grad_hooks[name] = handle

    def watch_grad_close(self):
        for _, handle in self.grad_hooks.items():
            handle.remove()  # remove the hook

    def add_grads(self, global_step=None, *args, **kwargs):
        """
        Add gradients to tensorboard.
        You must call the method self.watch_grad before using this method!
        """
        assert hasattr(self, 'grads'),\
            "self.grads is nonexistent! You must call self.watch_grad before!"
        assert self.grads, "self.grads is empty!"
        for (name, grad) in self.grads.items():
            self.add_histogram(tag=name,
                               values=grad,
                               global_step=global_step,
                               *args,
                               **kwargs)

    @staticmethod
    def make_desc(counter, total, *triples):
        """Build a progress-bar description like ``[i/N] name val (avg)``."""
        desc = "[{}/{}]".format(counter, total)
        # The three elements of each triple are
        # (name to display, AverageMeter object, formatting string)
        for name, obj, fmt in triples:
            desc += (" {} {obj.val:" + fmt + "} ({obj.avg:" + fmt +
                     "})").format(name, obj=obj)
        return desc
class Logger(object):
    """Training/eval metrics logger (variant with image logging).

    Scalars are mirrored into two ``MetersGroup`` aggregators (train / eval)
    and, when ``save_tb`` is set, into a tensorboard ``SummaryWriter``.
    """

    def __init__(self, log_dir, save_tb=False, log_frequency=10000, agent='sac'):
        self._log_dir = log_dir
        self._log_frequency = log_frequency
        if save_tb:
            tb_dir = os.path.join(log_dir, 'tb')
            if os.path.exists(tb_dir):
                try:
                    shutil.rmtree(tb_dir)
                except OSError:
                    # Best effort: a stale tb dir only means mixed event files.
                    print("logger.py warning: Unable to remove tb directory")
            self._sw = SummaryWriter(tb_dir)
        else:
            self._sw = None
        # each agent has specific output format for training
        assert agent in AGENT_TRAIN_FORMAT
        train_format = COMMON_TRAIN_FORMAT + AGENT_TRAIN_FORMAT[agent]
        self._train_mg = MetersGroup(os.path.join(log_dir, 'train'),
                                     formating=train_format)
        self._eval_mg = MetersGroup(os.path.join(log_dir, 'eval'),
                                    formating=COMMON_EVAL_FORMAT)

    def _should_log(self, step, log_frequency):
        # A per-call frequency (if given) overrides the instance default.
        log_frequency = log_frequency or self._log_frequency
        return step % log_frequency == 0

    def _try_sw_log(self, key, value, step):
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def _try_sw_log_image(self, key, image, step):
        if self._sw is not None:
            assert image.dim() == 3
            grid = torchvision.utils.make_grid(image.unsqueeze(1))
            self._sw.add_image(key, grid, step)

    def _try_sw_log_video(self, key, frames, step):
        if self._sw is not None:
            frames = torch.from_numpy(np.array(frames))
            # add a leading batch dimension for add_video's (N,T,C,H,W) layout
            frames = frames.unsqueeze(0)
            self._sw.add_video(key, frames, step, fps=30)

    def _try_sw_log_histogram(self, key, histogram, step):
        if self._sw is not None:
            self._sw.add_histogram(key, histogram, step)

    def log(self, key, value, step, n=1, log_frequency=1):
        """Log a scalar (averaged over ``n`` samples) under ``key``."""
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith('train') or key.startswith('eval')
        if type(value) == torch.Tensor:
            value = value.item()
        self._try_sw_log(key, value / n, step)
        mg = self._train_mg if key.startswith('train') else self._eval_mg
        mg.log(key, value, n)

    def log_param(self, key, param, step, log_frequency=None):
        """Log weight/bias histograms (and their grads, when present)."""
        if not self._should_log(step, log_frequency):
            return
        self.log_histogram(key + '_w', param.weight.data, step)
        if hasattr(param.weight, 'grad') and param.weight.grad is not None:
            self.log_histogram(key + '_w_g', param.weight.grad.data, step)
        if hasattr(param, 'bias') and hasattr(param.bias, 'data'):
            self.log_histogram(key + '_b', param.bias.data, step)
            if hasattr(param.bias, 'grad') and param.bias.grad is not None:
                self.log_histogram(key + '_b_g', param.bias.grad.data, step)

    def log_image(self, key, image, step, log_frequency=None):
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_image(key, image, step)

    def log_video(self, key, frames, step, log_frequency=None):
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_video(key, frames, step)

    def log_histogram(self, key, histogram, step, log_frequency=None):
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_histogram(key, histogram, step)

    def dump(self, step, save=True, ty=None):
        """Dump aggregated meters; ``ty`` is None (both), 'train', or 'eval'."""
        if ty is None:
            self._train_mg.dump(step, 'train', save)
            self._eval_mg.dump(step, 'eval', save)
        elif ty == 'eval':
            self._eval_mg.dump(step, 'eval', save)
        elif ty == 'train':
            self._train_mg.dump(step, 'train', save)
        else:
            # BUG FIX: ``raise f"..."`` raised a str, itself a TypeError in
            # Python 3 ("exceptions must derive from BaseException").
            raise ValueError(f'invalid log type: {ty}')
class DisentanglingTrainer(LatentTrainer): def __init__( self, env, log_dir, state_rep, skill_policy_path, seed, run_id, feature_dim=5, num_sequences=80, cuda=False, ): parent_kwargs = dict( num_steps=3000000, initial_latent_steps=100000, batch_size=256, latent_batch_size=128, num_sequences=num_sequences, latent_lr=0.0001, feature_dim=feature_dim, # Note: Only used if state_rep == False latent1_dim=8, latent2_dim=32, mode_dim=2, hidden_units=[256, 256], hidden_units_decoder=[256, 256], hidden_units_mode_encoder=[256, 256], hidden_rnn_dim=64, rnn_layers=2, memory_size=1e5, leaky_slope=0.2, grad_clip=None, start_steps=10000, training_log_interval=100, learning_log_interval=50, cuda=cuda, seed=seed) # Other self.run_id = run_id self.state_rep = state_rep # Comment for summery writer summary_comment = self.run_id # Environment self.env = env self.observation_shape = self.env.observation_space.shape self.action_shape = self.env.action_space.shape self.action_repeat = self.env.action_repeat # Seed torch.manual_seed(parent_kwargs['seed']) np.random.seed(parent_kwargs['seed']) self.env.seed(parent_kwargs['seed']) # Device self.device = torch.device("cuda" if parent_kwargs['cuda'] and torch.cuda.is_available() else "cpu") # Latent Network self.latent = ModeDisentanglingNetwork( self.observation_shape, self.action_shape, feature_dim=parent_kwargs['feature_dim'], latent1_dim=parent_kwargs['latent1_dim'], latent2_dim=parent_kwargs['latent2_dim'], mode_dim=parent_kwargs['mode_dim'], hidden_units=parent_kwargs['hidden_units'], hidden_units_decoder=parent_kwargs['hidden_units_decoder'], hidden_units_mode_encoder=parent_kwargs[ 'hidden_units_mode_encoder'], rnn_layers=parent_kwargs['rnn_layers'], hidden_rnn_dim=parent_kwargs['hidden_rnn_dim'], leaky_slope=parent_kwargs['leaky_slope'], state_rep=state_rep).to(self.device) # Load pretrained DIAYN skill policy data = torch.load(skill_policy_path) self.policy = data['evaluation/policy'] print("Policy loaded") # MI-Gradient score 
estimators self.spectral_j = SpectralScoreEstimator(n_eigen_threshold=0.99) self.spectral_m = SpectralScoreEstimator(n_eigen_threshold=0.99) # Optimization self.latent_optim = Adam(self.latent.parameters(), lr=parent_kwargs['latent_lr']) # Memory self.memory = MyMemoryDisentangling( state_rep=state_rep, capacity=parent_kwargs['memory_size'], num_sequences=parent_kwargs['num_sequences'], observation_shape=self.observation_shape, action_shape=self.action_shape, device=self.device) # Log directories self.log_dir = log_dir self.model_dir = os.path.join(log_dir, 'model') self.summary_dir = os.path.join(log_dir, 'summary') self.images_dir = os.path.join(log_dir, 'images') if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) if not os.path.exists(self.summary_dir): os.makedirs(self.summary_dir) if not os.path.exists(self.images_dir): os.makedirs(self.images_dir) # Summary writer with conversion of hparams # (certain types are not aloud for hparam-storage) self.writer = SummaryWriter(os.path.join(self.summary_dir, summary_comment), filename_suffix=self.run_id) hparam_dict = parent_kwargs.copy() for k, v in hparam_dict.items(): if isinstance(v, type(None)): hparam_dict[k] = 'None' if isinstance(v, list): hparam_dict[k] = torch.Tensor(v) hparam_dict['hidden_units'] = torch.Tensor( parent_kwargs['hidden_units']) self.writer.add_hparams(hparam_dict=hparam_dict, metric_dict={}) # Set hyperparameters self.steps = 0 self.learning_steps = 0 self.episodes = 0 self.initial_latent_steps = parent_kwargs['initial_latent_steps'] self.num_sequences = num_sequences self.num_steps = parent_kwargs['num_steps'] self.batch_size = parent_kwargs['batch_size'] self.latent_batch_size = parent_kwargs['latent_batch_size'] self.start_steps = parent_kwargs['start_steps'] self.grad_clip = parent_kwargs['grad_clip'] self.training_log_interval = parent_kwargs['training_log_interval'] self.learning_log_interval = parent_kwargs['learning_log_interval'] # Mode action sampler 
self.mode_action_sampler = ModeActionSampler(self.latent, device=self.device) def get_skill_action_pixel(self): obs_state_space = self.env.get_state_obs() action, info = self.policy.get_action(obs_state_space) return action def get_skill_action_state_rep(self, observation): action, info = self.policy.get_action(observation) return action def set_policy_skill(self, skill): self.policy.stochastic_policy.skill = skill def train_episode(self): self.episodes += 1 episode_steps = 0 episode_reward = 0. done = False state = self.env.reset() self.memory.set_initial_state(state) skill = np.random.randint(self.policy.stochastic_policy.skill_dim) #skill = np.random.choice([3, 4, 5, 7, 9], 1).item() self.set_policy_skill(skill) next_state = state while not done and episode_steps <= self.num_sequences + 2: action = self.get_skill_action_state_rep(next_state) if self.state_rep \ else self.get_skill_action_pixel() next_state, reward, done, _ = self.env.step(action) self.steps += self.action_repeat episode_steps += self.action_repeat episode_reward += reward self.memory.append(action, skill, next_state, done) if self.is_update(): if self.learning_steps < self.initial_latent_steps: print('-' * 60) print('Learning the disentangled model only...') for _ in tqdm(range(self.initial_latent_steps)): self.learning_steps += 1 self.learn_latent() print('Finished learning the disentangled model') print('-' * 60) print(f'episode: {self.episodes:<4} ' f'episode steps: {episode_steps:<4} ' f'skill: {skill:<4} ') self.save_models() def learn_latent(self): # Sample sequence images_seq, actions_seq, skill_seq, dones_seq = \ self.memory.sample_latent(self.latent_batch_size) # Calc loss latent_loss = self.calc_latent_loss(images_seq, actions_seq, skill_seq, dones_seq) # Backprop update_params(self.latent_optim, self.latent, latent_loss, self.grad_clip) # Write net params if self._is_log(self.learning_log_interval * 5): self.latent.write_net_params(self.writer, self.learning_steps) def 
def calc_latent_loss(self, images_seq, actions_seq, skill_seq, dones_seq):
    """Compute the training loss of the latent variable model.

    Combines KL terms, action reconstruction (MSE / log-likelihood), MMD
    regularizers and a mutual-information gradient surrogate.

    Args:
        images_seq: batch of observation sequences (fed to self.latent.encoder).
        actions_seq: batch of action sequences; assumed (batch, time, action_dim)
            based on the .size(2) / [:, :-1, :] usages below — TODO confirm.
        skill_seq: ground-truth skill labels, forwarded to the latent-map plot.
        dones_seq: unused here — kept for interface compatibility.

    Returns:
        Scalar tensor: the latent model loss (scaled by 10).
    """
    # Get features from images
    features_seq = self.latent.encoder(images_seq)

    # Sample from posterior dynamics
    ((latent1_post_samples, latent2_post_samples, mode_post_samples),
     (latent1_post_dists, latent2_post_dists, mode_post_dist)) = \
        self.latent.sample_posterior(actions_seq=actions_seq,
                                     features_seq=features_seq)

    # Sample from prior dynamics
    ((latent1_pri_samples, latent2_pri_samples, mode_pri_samples),
     (latent1_pri_dists, latent2_pri_dists, mode_pri_dist)) = \
        self.latent.sample_prior(features_seq)

    # KL divergence losses, normalized by latent dim and sequence length
    latent_kld = calc_kl_divergence(latent1_post_dists, latent1_pri_dists)
    latent1_dim = latent1_post_samples.size(2)
    seq_length = latent1_post_samples.size(1)
    latent_kld /= latent1_dim
    latent_kld /= seq_length
    mode_kld = calc_kl_divergence([mode_post_dist], [mode_pri_dist])
    mode_dim = mode_post_samples.size(2)
    mode_kld /= mode_dim
    kld_losses = mode_kld + latent_kld

    # Log likelihood loss of generated actions (full posterior reconstruction)
    actions_seq_dists = self.latent.decoder(
        latent1_sample=latent1_post_samples,
        latent2_sample=latent2_post_samples,
        mode_sample=mode_post_samples)
    log_likelihood = actions_seq_dists.log_prob(actions_seq).mean(
        dim=0).mean()
    mse = torch.nn.functional.mse_loss(actions_seq_dists.loc, actions_seq)

    # Log likelihood of actions generated from latent-dynamics PRIORS
    # (detached) combined with the mode POSTERIOR.
    # NOTE(review): ll_dyn_pri_mode_post is computed but never used in the
    # active loss below — diagnostic only.
    actions_seq_dists_mode = self.latent.decoder(
        latent1_sample=latent1_pri_samples.detach(),
        latent2_sample=latent2_pri_samples.detach(),
        mode_sample=mode_post_samples)
    ll_dyn_pri_mode_post = actions_seq_dists_mode.\
        log_prob(actions_seq).mean(dim=0).mean()

    # Log likelihood with latent-dynamics POSTERIORS and the mode PRIOR.
    # NOTE(review): ll_dyn_post_mode_pri is also unused in the active loss.
    action_seq_dists_dyn = self.latent.decoder(
        latent1_sample=latent1_post_samples,
        latent2_sample=latent2_post_samples,
        mode_sample=mode_pri_samples)
    ll_dyn_post_mode_pri = action_seq_dists_dyn.log_prob(actions_seq).mean(
        dim=0).mean()

    # Maximum Mean Discrepancy (MMD) between prior and posterior samples.
    # Only the first time step of the mode samples is used (the mode is
    # presumably constant over the sequence — TODO confirm).
    mode_pri_sample = mode_pri_samples[:, 0, :]
    mode_post_sample = mode_post_samples[:, 0, :]
    mmd_mode = self.compute_mmd_tutorial(mode_pri_sample, mode_post_sample)
    mmd_latent = 0
    # (per-time-step MMD variant retained from the original, commented out)
    #latent1_post_samples_trans = latent1_post_samples.transpose(0, 1)
    #latent1_pri_samples_trans = latent1_pri_samples.transpose(0, 1)
    #for idx in range(latent1_post_samples_trans.size(0)):
    #    mmd_latent += self.compute_mmd_tutorial(latent1_pri_samples_trans[idx],
    #                                            latent1_post_samples_trans[idx])
    latent1_post_samples_trans = latent1_post_samples.\
        view(latent1_post_samples.size(0), -1)
    latent1_pri_samples_trans = latent1_pri_samples.\
        view(latent1_pri_samples.size(0), -1)
    mmd_latent = self.compute_mmd_tutorial(latent1_pri_samples_trans,
                                           latent1_post_samples_trans)
    mmd_mode_weighted = mmd_mode
    mmd_latent_weighted = mmd_latent
    mmd_loss = mmd_latent_weighted + mmd_mode_weighted

    # MI-Gradient (mutual-information gradient surrogate estimators)
    # m - data
    batch_size = mode_post_samples.size(0)
    features_actions_seq = torch.cat(
        [features_seq, actions_seq[:, :-1, :]], dim=2)
    xs = features_actions_seq.view(batch_size, -1)
    ys = mode_post_sample
    xs_ys = torch.cat([xs, ys], dim=1)
    gradient_estimator_m_data = entropy_surrogate(self.spectral_j, xs_ys) \
        - entropy_surrogate(self.spectral_m, ys)

    # m_pri - gen_data
    xs = mode_pri_sample
    ys = actions_seq_dists.loc.view(batch_size, -1)
    xs_ys = torch.cat([xs, ys], dim=1)
    gradient_estimator_m_gendata = entropy_surrogate(self.spectral_j, xs_ys) \
        - entropy_surrogate(self.spectral_m, ys)

    # m_post - latent_post
    #xs = mode_post_sample
    #gradient_estimator_mpost_latentpost = 0
    #for idx in range(latent1_post_samples.size(1)):
    #    ys = latent1_post_samples[:, idx, :]
    #    xs_ys = torch.cat([xs, ys], dim=1)
    #    single_estimator = entropy_surrogate(self.spectral_j, xs_ys) \
    #        - entropy_surrogate(self.spectral_m, ys)
    #    gradient_estimator_mpost_latentpost += single_estimator
    xs = latent1_post_samples.view(batch_size, -1)
    ys = mode_post_sample
    xs_ys = torch.cat([xs, ys], dim=1)
    # NOTE(review): this OVERWRITES gradient_estimator_m_gendata computed
    # above — given the surrounding comments it was probably meant to be a
    # separate variable (e.g. gradient_estimator_mpost_latentpost). The
    # logged 'm_pri generated data' value therefore reports this quantity,
    # not the m_pri/gen_data one. Preserved as-is; verify intent.
    gradient_estimator_m_gendata = entropy_surrogate(self.spectral_j, xs_ys) \
        - entropy_surrogate(self.spectral_m, ys)

    # m-post - z-post
    #xs = mode_post_sample
    #ys = torch.cat([latent1_post_samples.view(batch_size, -1),
    #                latent2_post_samples.view(batch_size, -1)], dim=1)
    #xs_ys = torch.cat([xs, ys], dim=1)
    #gradient_estimator_m_post_z_post = entropy_surrogate(self.spectral_j, xs_ys) \
    #    - entropy_surrogate(self.spectral_m, ys)

    # Loss — info-VAE style weighting of KL vs MMD terms
    reg_weight = 1000.
    alpha = 0.99
    kld_info_weighted = (1. - alpha) * kld_losses
    mmd_info_weighted = (alpha + reg_weight - 1.) * mmd_loss
    reg_weight_mode = 100.
    alpha_mode = 1.
    mmd_mode_info_weighted = \
        (alpha_mode + reg_weight_mode - 1.) * mmd_mode_weighted
    kld_mode_info_weighted = (1. - alpha_mode) * mode_kld
    reg_weight_latent = 100.
    alpha_latent = 0
    mmd_latent_info_weighted = \
        (alpha_latent + reg_weight_latent - 1.) * mmd_latent_weighted
    kld_latent_info_weighted = (1. - alpha_latent) * latent_kld
    # NOTE(review): loss_X and loss_Z are computed but unused by the active
    # loss formulation below.
    loss_X = -log_likelihood
    loss_Z = kld_info_weighted + mmd_info_weighted
    # Active loss: MSE reconstruction + KL + MI-gradient surrogate.
    latent_loss = mse + kld_losses - 1 * gradient_estimator_m_data
    # (alternative loss formulations retained from the original)
    #latent_loss = kld_info_weighted - log_likelihood + mmd_info_weighted
    #latent_loss = -log_likelihood \
    #    + 0.01 * latent_kld + mode_kld \
    #    - 1 * gradient_estimator_m_data \
    #    - 0 * gradient_estimator_m_gendata \
    #    #+ 1 * gradient_estimator_m_post_z_post
    #latent_loss = -log_likelihood + kld_info_weighted + mmd_info_weighted
    #latent_loss = mse\
    #    + mmd_mode_info_weighted \
    #    + kld_mode_info_weighted \
    #    + mmd_latent_info_weighted \
    #    + kld_latent_info_weighted \
    #    #+ 1 * gradient_estimator_mpost_latentpost
    latent_loss *= 10

    # Logging (scalars, diagnostics, reconstruction/mode tests, model save)
    if self._is_log(self.learning_log_interval):
        # Reconstruction error
        reconst_error = (actions_seq - actions_seq_dists.loc) \
            .pow(2).mean(dim=(0, 1)).sum()
        self._summary_log('stats/reconst_error', reconst_error)
        # NOTE(review): print() does not do %-interpolation; this prints the
        # tuple ('reconstruction error: %f', tensor). Likely meant a format.
        print('reconstruction error: %f', reconst_error)
        reconst_err_mode_post = (actions_seq - actions_seq_dists_mode.loc)\
            .pow(2).mean(dim=(0, 1)).sum()
        reconst_err_dyn_post = (actions_seq - action_seq_dists_dyn.loc)\
            .pow(2).mean(dim=(0, 1)).sum()
        self._summary_log('stats/reconst_error mode post',
                          reconst_err_mode_post)
        self._summary_log('stats/reconst_error dyn post',
                          reconst_err_dyn_post)

        # KL divergence (unnormalized, for reference)
        mode_kldiv_standard = calc_kl_divergence([mode_post_dist],
                                                 [mode_pri_dist])
        seq_kldiv_standard = calc_kl_divergence(latent1_post_dists,
                                                latent1_pri_dists)
        kldiv_standard = mode_kldiv_standard + seq_kldiv_standard
        self._summary_log('stats_kldiv_standard/kldiv_standard',
                          kldiv_standard)
        self._summary_log('stats_kldiv_standard/mode_kldiv_standard',
                          mode_kldiv_standard)
        self._summary_log('stats_kldiv_standard/seq_kldiv_standard',
                          seq_kldiv_standard)
        self._summary_log('stats_kldiv/mode_kldiv_used_for_loss', mode_kld)
        self._summary_log('stats_kldiv/latent_kldiv_used_for_loss',
                          latent_kld)
        self._summary_log('stats_kldiv/klddiv used for loss', kld_losses)

        # Log Likelyhood
        self._summary_log('stats/log-likelyhood', log_likelihood)
        self._summary_log('stats/mse', mse)

        # MMD
        self._summary_log('stats_mmd/mmd_weighted', mmd_info_weighted)
        self._summary_log('stats_mmd/kld_weighted', kld_info_weighted)
        self._summary_log('stats_mmd/mmd_mode_weighted', mmd_mode_weighted)
        self._summary_log('stats_mmd/mmd_latent_weighted',
                          mmd_latent_weighted)
        self._summary_log('stats_mmd_separated/mmd_mode_info_weighted',
                          mmd_mode_info_weighted)
        self._summary_log('stats_mmd_separated/kld_mode_info_weighted',
                          kld_mode_info_weighted)
        self._summary_log('stats_mmd_separated/mmd_latent_info_weighted',
                          mmd_latent_info_weighted)
        self._summary_log('stats_mmd_separated/kld_latent_info_weighted',
                          kld_latent_info_weighted)
        self._summary_log(
            'stats_mmd_separated/loss_latentZ',
            mmd_latent_info_weighted + kld_latent_info_weighted)
        self._summary_log('stats_mmd_separated/loss_modeZ',
                          mmd_mode_info_weighted + kld_mode_info_weighted)

        # MI-Grad
        self._summary_log('stats_mi/mi_grad_est m_pri generated data',
                          gradient_estimator_m_gendata)
        self._summary_log('stats_mi/mi_grad_est m_post data',
                          gradient_estimator_m_data)
        #self._summary_log('stats_mi/mi_grad_est m_post z_post',
        #                  gradient_estimator_m_post_z_post)

        # Loss
        self._summary_log('loss/network', latent_loss)

        # Save Model
        self.latent.save(os.path.join(self.model_dir, 'model.pth'))

        # Reconstruction Test on one random batch element
        rand_batch_idx = np.random.choice(actions_seq.size(0))
        self._reconstruction_post_test(rand_batch_idx, actions_seq,
                                       actions_seq_dists, images_seq)
        self._reconstruction_mode_post_test(
            rand_batch_idx=rand_batch_idx,
            actions_seq=actions_seq,
            mode_post_samples=mode_post_samples,
            latent1_pri_samples=latent1_pri_samples,
            latent2_pri_samples=latent2_pri_samples)
        self._reconstruction_dyn_post_test(rand_batch_idx, actions_seq,
                                           latent1_post_samples,
                                           latent2_post_samples,
                                           mode_pri_samples)

        # Latent Test
        self._plot_latent_mode_map(skill_seq, mode_post_samples)
        self._gen_mode_grid_graph(mode_post_samples)

        # Mode influence test (expensive: full env rollouts -> rarer)
        if self._is_log(self.learning_log_interval * 10):
            self._gen_mode_grid_videos(mode_post_samples)

    return latent_loss

def _gen_mode_grid_videos(self, mode_post_samples):
    """Roll out the policy for each mode on a grid and log rendered videos.

    NOTE(review): np.float is removed in NumPy >= 1.24 — the astype calls
    below will raise there; consider np.float64/float.
    """
    seq_len = 200
    with torch.no_grad():
        modes = self._create_grid(mode_post_samples)
        for (mode_idx, mode) in enumerate(modes):
            obs = self.env.reset()
            img = self.env.render()
            # HWC -> CHW (transpose swaps first and last dims), add batch dim
            img_seq = torch.from_numpy(img.astype(np.float)).transpose(
                0, -1).unsqueeze(0)
            self.mode_action_sampler.reset(mode=mode.unsqueeze(0))
            for step in range(seq_len):
                action = self.mode_action_sampler(
                    self.latent.encoder(
                        torch.Tensor(obs.astype(np.float)).to(
                            self.device).unsqueeze(0)))
                obs, _, done, _ = self.env.step(
                    action.detach().cpu().numpy()[0])
                img = self.env.render()
                img = torch.from_numpy(img.astype(np.float))\
                    .transpose(0, -1).unsqueeze(0)
                img_seq = torch.cat([img_seq, img], dim=0)
            self.writer.add_video('mode_generation_video/mode'
                                  + str(mode_idx),
                                  vid_tensor=img_seq.unsqueeze(0).float(),
                                  global_step=self.learning_steps)

def _gen_mode_grid_graph(self, mode_post_samples):
    """Roll out each grid mode and log an actions/states line plot."""
    #TODO make the method universial in terms of envs
    assert len(self.env.action_space.shape) == 1,\
        'Method only works in MountainCar Case'
    seq_len = mode_post_samples.size(1)
    with torch.no_grad():
        modes = self._create_grid(mode_post_samples)
        for (mode_idx, mode) in enumerate(modes):
            obs = self.env.reset()
            self.mode_action_sampler.reset(mode=mode.unsqueeze(0))
            action = self.mode_action_sampler(
                self.latent.encoder(
                    torch.from_numpy(obs).to(
                        self.device).unsqueeze(0).float()))
            action = action.detach().cpu().numpy()[0]
            obs_save = np.expand_dims(obs, axis=0)
            actions_save = [action]
            for step in range(seq_len):
                obs, _, done, _ = self.env.step(action)
                action = self.mode_action_sampler(
                    self.latent.encoder(
                        torch.from_numpy(obs).to(
                            self.device).unsqueeze(0).float()))
                action = action.detach().cpu().numpy()[0]
                actions_save = np.concatenate(
                    (actions_save, np.expand_dims(action, axis=0)), axis=0)
                obs_save = np.concatenate(
                    (obs_save, np.expand_dims(obs, axis=0)), axis=0)
            plt.interactive(False)
            axes = plt.gca()
            axes.set_ylim([-1.5, 1.5])
            plt.plot(actions_save, label='actions')
            for dim in range(obs_save.shape[1]):
                plt.plot(obs_save[:, dim], label='state_dim' + str(dim))
            fig = plt.gcf()
            self.writer.add_figure('mode_grid_plot_test/mode'
                                   + str(mode_idx),
                                   figure=fig,
                                   global_step=self.learning_steps)

def _create_grid(self, mode_post_samples):
    """Return a Cartesian grid of mode vectors, 4 points per dim in [-2, 2].

    Result shape: (4 ** mode_dim) x mode_dim, on self.device.
    """
    mode_dim = mode_post_samples.size(2)
    grid_vec = torch.linspace(-2., 2., 4)
    grid_vec_list = [grid_vec] * mode_dim
    grid = torch.meshgrid(*grid_vec_list)
    modes = torch.stack(list(grid)).view(mode_dim, -1) \
        .transpose(0, -1).to(self.device)  # N x mode_dim
    return modes

def _plot_latent_mode_map(self, skill_seq, mode_post_samples):
    """Scatter-plot inferred 2D mode posterior samples colored by skill.

    NOTE(review): both arguments are ignored — a fresh batch of 128
    sequences is sampled from memory instead. Only plots when the mode
    space is 2-dimensional. np.int is removed in NumPy >= 1.24.
    """
    with torch.no_grad():
        images_seq, actions_seq, skill_seq, dones_seq = \
            self.memory.sample_latent(128)
        features_seq = self.latent.encoder(images_seq)
        mode_post_dist = self.latent.mode_posterior(
            features_seq=features_seq.transpose(0, 1),
            actions_seq=actions_seq.transpose(0, 1))
        mode_post_sample = mode_post_dist.rsample()
        if mode_post_sample.size(1) == 2:
            colors = [
                'b', 'g', 'r', 'c', 'm', 'y', 'k', 'darkorange', 'gray',
                'lightgreen'
            ]
            skills = skill_seq.mean(
                dim=1).detach().cpu().squeeze().numpy().astype(np.int)
            plt.interactive(False)
            axes = plt.gca()
            axes.set_ylim([-3, 3])
            axes.set_xlim([-3, 3])
            #for (idx, skill) in enumerate(skills):
            #    color = colors[skill.item()]
            #    plt.scatter(mode_post_samples[idx, 0].detach().cpu().numpy(),
            #                mode_post_samples[idx, 1].detach().cpu().numpy(),
            #                label=skill, c=color)
            # one scatter call per skill id (assumes exactly 10 skills)
            for skill in range(10):
                idx = skills == skill
                color = colors[skill]
                plt.scatter(mode_post_sample[idx, 0].detach().cpu().numpy(),
                            mode_post_sample[idx, 1].detach().cpu().numpy(),
                            label=skill, c=color)
            axes.legend()
            axes.grid(True)
            fig = plt.gcf()
            self.writer.add_figure('Latent_test/mode mapping', fig,
                                   global_step=self.learning_steps)

def _reconstruction_post_test(self, rand_batch_idx, actions_seq,
                              actions_seq_dists, states):
    """Test reconstruction of inferred posterior.

    Args:
        rand_batch_idx: which batch element to plot.
        actions_seq: ground-truth action sequences from data.
        actions_seq_dists: distribution of inferred posterior actions
            (reconstruction).
        states: observation sequences, overplotted per state dim.
    """
    # Reconstruction test
    rand_batch = rand_batch_idx
    action_dim = actions_seq.size(2)
    gt_actions = actions_seq[rand_batch].detach().cpu()
    post_actions = actions_seq_dists.loc[rand_batch].detach().cpu()
    states = states[rand_batch].detach().cpu()
    for dim in range(action_dim):
        fig = self._reconstruction_test_plot(dim, gt_actions, post_actions,
                                             states)
        self.writer.add_figure('Reconst_post_test/reconst test dim'
                               + str(dim),
                               fig,
                               global_step=self.learning_steps)
        plt.clf()

def _reconstruction_mode_post_test(self, rand_batch_idx, actions_seq,
                                   latent1_pri_samples, latent2_pri_samples,
                                   mode_post_samples):
    """Test if mode inference works: decode with dynamics PRIORS but the
    inferred mode POSTERIOR, then plot against ground truth.

    Args:
        rand_batch_idx: which batch element to use.
        actions_seq: ground-truth action sequences.
        latent1_pri_samples: samples from the dynamics prior.
        latent2_pri_samples: samples from the dynamics prior.
        mode_post_samples: sample from the inferred mode posterior.
    """
    # Use random sample from batch
    rand_batch = rand_batch_idx

    # Decode
    actions_seq_dists = self.latent.decoder(
        latent1_sample=latent1_pri_samples[rand_batch],
        latent2_sample=latent2_pri_samples[rand_batch],
        mode_sample=mode_post_samples[rand_batch])

    # Reconstruction test
    action_dim = actions_seq.size(2)
    gt_actions = actions_seq[rand_batch].detach().cpu()
    post_actions = actions_seq_dists.loc.detach().cpu()

    # Plot
    for dim in range(action_dim):
        fig = self._reconstruction_test_plot(dim, gt_actions, post_actions)
        self.writer.add_figure('Reconst_mode_post_test/reconst test dim'
                               + str(dim),
                               fig,
                               global_step=self.learning_steps)

def _reconstruction_dyn_post_test(self, rand_batch_idx, actions_seq,
                                  latent1_post_samples, latent2_post_samples,
                                  mode_pri_samples):
    """Test the influence of the latent dynamics inference: decode with
    dynamics POSTERIORS but the mode PRIOR, then plot against ground truth.
    """
    # Use random sample from batch
    rand_batch = rand_batch_idx

    # Decode
    actions_seq_dists = self.latent.decoder(
        latent1_sample=latent1_post_samples[rand_batch],
        latent2_sample=latent2_post_samples[rand_batch],
        mode_sample=mode_pri_samples[rand_batch])

    # Reconstruction test
    action_dim = actions_seq.size(2)
    gt_actions = actions_seq[rand_batch].detach().cpu()
    post_actions = actions_seq_dists.loc.detach().cpu()

    # Plot
    for dim in range(action_dim):
        fig = self._reconstruction_test_plot(dim, gt_actions, post_actions)
        self.writer.add_figure('Reconst_dyn_post_test/reconst test dim'
                               + str(dim),
                               fig,
                               global_step=self.learning_steps)

def _reconstruction_test_plot(self, dim, gt_actions, post_actions,
                              states=None):
    """Plot ground-truth vs reconstructed actions for one action dim.

    NOTE(review): the inner `for dim in range(states.size(1))` loop shadows
    the `dim` parameter; harmless here since `dim` is not used afterwards.
    """
    plt.interactive(False)
    axes = plt.gca()
    axes.set_ylim([-1.5, 1.5])
    plt.plot(gt_actions[:, dim].numpy())
    plt.plot(post_actions[:, dim].numpy())
    if states is not None:
        for dim in range(states.size(1)):
            plt.plot(states[:, dim].numpy())
    fig = plt.gcf()
    return fig

def compute_kernel_tutorial(self, x, y):
    """Gaussian kernel matrix between rows of x and y (both 2-D tensors)."""
    x_size = x.size(0)
    y_size = y.size(0)
    dim = x.size(1)
    x = x.unsqueeze(1)  # (x_size, 1, dim)
    y = y.unsqueeze(0)  # (1, y_size, dim)
    tiled_x = x.expand(x_size, y_size, dim)
    tiled_y = y.expand(x_size, y_size, dim)
    kernel_input = (tiled_x - tiled_y).pow(2).mean(2) / float(dim)
    return torch.exp(-kernel_input)  # (x_size, y_size)

def compute_mmd_tutorial(self, x, y):
    """Biased MMD^2 estimate between sample sets x and y (same shape)."""
    assert x.shape == y.shape
    x_kernel = self.compute_kernel_tutorial(x, x)
    y_kernel = self.compute_kernel_tutorial(y, y)
    xy_kernel = self.compute_kernel_tutorial(x, y)
    mmd = x_kernel.mean() + y_kernel.mean() - 2 * xy_kernel.mean()
    return mmd

def save_models(self):
    """Save state dict and whole latent model under model_dir/run_id prefix.

    NOTE(review): os.path.join(self.model_dir, self.run_id) + '...' does
    string concatenation, so run_id acts as a filename PREFIX, not a dir.
    """
    path_name = os.path.join(self.model_dir, self.run_id)
    self.latent.save(path_name + 'model_state_dict.pth')
    torch.save(self.latent, path_name + 'whole_model.pth')
    #np.save(self.memory, os.path.join(self.model_dir, 'memory.pth'))

def _is_log(self, log_interval):
    # True every log_interval learning steps
    return True if self.learning_steps % log_interval == 0 else False

def _summary_log(self, data_name, data):
    # Accept tensors or plain numbers; tensors are converted to python scalars
    if type(data) == torch.Tensor:
        data = data.detach().cpu().item()
    self.writer.add_scalar(data_name, data, self.learning_steps)
class SummaryWriter:
    """Gated wrapper around torch's SummaryWriter (imported here as
    TensorboardSummaryWriter).

    Adds two conveniences over the raw writer:
      * an ``active`` flag — when False, every ``add_*`` call is a no-op;
      * a default ``global_step`` — used whenever a call omits its own step.

    NOTE(review): ``__init__`` also mutates the enclosing MODULE, exporting
    each ``add_*`` method (without the ``add_`` prefix) and each ``set_*``
    method as module-level functions bound to this instance. Creating a
    second instance silently rebinds those module-level names.
    """

    def __init__(self, logdir, flush_secs=120):
        self.writer = TensorboardSummaryWriter(
            log_dir=logdir,
            purge_step=None,
            max_queue=10,
            flush_secs=flush_secs,
            filename_suffix='')
        self.global_step = None   # default step used when a call passes None
        self.active = True        # when False, all add_* calls are no-ops

        # ------------------------------------------------------------------------
        # register add_* and set_* functions in summary module on instantiation
        # ------------------------------------------------------------------------
        this_module = sys.modules[__name__]
        list_of_names = dir(SummaryWriter)
        for name in list_of_names:
            # add functions (without the 'add' prefix)
            if name.startswith('add_'):
                setattr(this_module, name[4:], getattr(self, name))
            # set functions
            if name.startswith('set_'):
                setattr(this_module, name, getattr(self, name))

    def set_global_step(self, value):
        # Set the fallback step applied when add_* calls omit global_step.
        self.global_step = value

    def set_active(self, value):
        # Enable/disable all add_* logging calls.
        self.active = value

    def add_audio(self, tag, snd_tensor, global_step=None, sample_rate=44100,
                  walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_audio(
                tag, snd_tensor, global_step=global_step,
                sample_rate=sample_rate, walltime=walltime)

    def add_custom_scalars(self, layout):
        # NOTE: no global_step fallback — the underlying call takes none.
        if self.active:
            self.writer.add_custom_scalars(layout)

    def add_custom_scalars_marginchart(self, tags, category='default',
                                       title='untitled'):
        if self.active:
            self.writer.add_custom_scalars_marginchart(tags,
                                                       category=category,
                                                       title=title)

    def add_custom_scalars_multilinechart(self, tags, category='default',
                                          title='untitled'):
        if self.active:
            self.writer.add_custom_scalars_multilinechart(tags,
                                                          category=category,
                                                          title=title)

    def add_embedding(self, mat, metadata=None, label_img=None,
                      global_step=None, tag='default', metadata_header=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_embedding(
                mat, metadata=metadata, label_img=label_img,
                global_step=global_step, tag=tag,
                metadata_header=metadata_header)

    def add_figure(self, tag, figure, global_step=None, close=True,
                   walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_figure(
                tag, figure, global_step=global_step, close=close,
                walltime=walltime)

    def add_graph(self, model, input_to_model=None, verbose=False):
        if self.active:
            self.writer.add_graph(model, input_to_model=input_to_model,
                                  verbose=verbose)

    def add_histogram(self, tag, values, global_step=None, bins='tensorflow',
                      walltime=None, max_bins=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_histogram(
                tag, values, global_step=global_step, bins=bins,
                walltime=walltime, max_bins=max_bins)

    def add_histogram_raw(self, tag, min, max, num, sum, sum_squares,
                          bucket_limits, bucket_counts, global_step=None,
                          walltime=None):
        # NOTE(review): parameter names min/max/sum shadow builtins; kept for
        # signature compatibility with the wrapped API.
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_histogram_raw(
                tag, min=min, max=max, num=num, sum=sum,
                sum_squares=sum_squares, bucket_limits=bucket_limits,
                bucket_counts=bucket_counts, global_step=global_step,
                walltime=walltime)

    def add_image(self, tag, img_tensor, global_step=None, walltime=None,
                  dataformats='CHW'):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_image(
                tag, img_tensor, global_step=global_step, walltime=walltime,
                dataformats=dataformats)

    def add_image_with_boxes(self, tag, img_tensor, box_tensor,
                             global_step=None, walltime=None, rescale=1,
                             dataformats='CHW'):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_image_with_boxes(
                tag, img_tensor, box_tensor, global_step=global_step,
                walltime=walltime, rescale=rescale, dataformats=dataformats)

    def add_images(self, tag, img_tensor, global_step=None, walltime=None,
                   dataformats='NCHW'):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_images(
                tag, img_tensor, global_step=global_step, walltime=walltime,
                dataformats=dataformats)

    def add_mesh(self, tag, vertices, colors=None, faces=None,
                 config_dict=None, global_step=None, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_mesh(
                tag, vertices, colors=colors, faces=faces,
                config_dict=config_dict, global_step=global_step,
                walltime=walltime)

    def add_onnx_graph(self, graph):
        if self.active:
            self.writer.add_onnx_graph(graph)

    def add_pr_curve(self, tag, labels, predictions, global_step=None,
                     num_thresholds=127, weights=None, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_pr_curve(
                tag, labels, predictions, global_step=global_step,
                num_thresholds=num_thresholds, weights=weights,
                walltime=walltime)

    def add_pr_curve_raw(self, tag, true_positive_counts,
                         false_positive_counts, true_negative_counts,
                         false_negative_counts, precision, recall,
                         global_step=None, num_thresholds=127, weights=None,
                         walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_pr_curve_raw(
                tag, true_positive_counts, false_positive_counts,
                true_negative_counts, false_negative_counts, precision,
                recall, global_step=global_step,
                num_thresholds=num_thresholds, weights=weights,
                walltime=walltime)

    def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_scalar(
                tag, scalar_value, global_step=global_step,
                walltime=walltime)

    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None,
                    walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_scalars(
                main_tag, tag_scalar_dict, global_step=global_step,
                walltime=walltime)

    def add_text(self, tag, text_string, global_step=None, walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_text(
                tag, text_string, global_step=global_step,
                walltime=walltime)

    def add_video(self, tag, vid_tensor, global_step=None, fps=4,
                  walltime=None):
        if self.active:
            global_step = self.global_step if global_step is None else global_step
            self.writer.add_video(
                tag, vid_tensor, global_step=global_step, fps=fps,
                walltime=walltime)

    def close(self):
        self.writer.close()

    def __enter__(self):
        # NOTE(review): delegates to the wrapped writer's context manager, so
        # the object yielded by `with` is the RAW writer — the active /
        # global_step gating of this wrapper is bypassed inside the block.
        return self.writer.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        return self.writer.__exit__(exc_type, exc_val, exc_tb)
def main():
    """Evaluate a trained video-prediction model on the test split.

    Parses CLI args, loads the selected model checkpoint, runs prediction on
    every test batch, optionally saves per-sample GIFs, logs a video grid to
    TensorBoard, and prints mean/std of MSE loss and inference time.

    Fix vs. original: an unrecognized --model_name previously left `model`
    unbound and crashed later with a confusing NameError; now it raises a
    clear ValueError up front.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default='data/bair')
    parser.add_argument('--model_path', type=str, default='model/bair')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--horizon', type=int, default=10)
    parser.add_argument('--cpu_workers', type=int, default=4)
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--model_name', type=str, default='cdna')
    parser.add_argument('--load_point', type=int, default=10)
    parser.add_argument('--no-gif', dest='save_gif', action='store_false')
    parser.set_defaults(save_gif=True)
    args = parser.parse_args()

    device = 'cuda:%d' % args.gpu_id if torch.cuda.device_count() > 0 else 'cpu'

    # dataset setup
    test_set = VideoDataset(args.data_path, 'test', args.horizon,
                            fix_start=True)
    config = test_set.get_config()
    H, W, C = config['observations']
    A = config['actions'][0]
    T = args.horizon
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              num_workers=args.cpu_workers)

    # model setup
    if args.model_name == 'cdna':
        model = CDNA(T, H, W, C, A)
    elif args.model_name == 'etd':
        model = ETD(H, W, C, A, T, 5)
    elif args.model_name == 'etds':
        model = ETDS(H, W, C, A, T, 5)
    elif args.model_name == 'etdm':
        model = ETDM(H, W, C, A, T, 5)
    elif args.model_name == 'etdsd':
        model = ETDSD(H, W, C, A, T, 5)
    else:
        # fail fast on typos instead of a NameError at model.to(device)
        raise ValueError('unknown --model_name: {}'.format(args.model_name))
    model.to(device)
    model_path = os.path.join(args.model_path,
                              '{}_10'.format(args.model_name))
    load_model(model,
               os.path.join(model_path,
                            '{}_{}.pt'.format(args.model_name,
                                              args.load_point)),
               eval_mode=True)

    # tensorboard
    writer = SummaryWriter()

    gif_path = os.path.join(model_path, 'test_{}'.format(args.horizon))
    if not os.path.exists(gif_path):
        os.makedirs(gif_path)

    losses = []
    videos = []
    inference_times = []
    with torch.no_grad():
        for j, data in enumerate(test_loader):
            observations = data['observations']
            actions = data['actions']
            # B x T ==> T x B
            observations = torch.transpose(observations, 0, 1).to(device)
            actions = torch.transpose(actions, 0, 1).to(device)

            start_time = time.time()
            predicted_observations = model(observations[0], actions)
            inference_times.append((time.time() - start_time) * 1000)

            # first ground-truth frame + T-1 predicted frames of sample 0
            video = torch.cat([observations[0, 0].unsqueeze(0),
                               predicted_observations[0:T - 1, 0]
                               ])  # tensor[T, C, H, W]
            videos.append(video.unsqueeze(0).detach().cpu())
            if args.save_gif:
                torch_save_gif(os.path.join(gif_path, "{}.gif".format(j)),
                               video.detach().cpu(), fps=10)

            loss = mse_loss(observations,
                            predicted_observations).item() / args.batch_size
            losses.append(loss)
            del loss, observations, actions, predicted_observations  # clear the memory

    videos = torch.cat(videos, 0)
    writer.add_video('test_video_{}_{}'.format(args.model_name,
                                               args.horizon),
                     videos, global_step=0, fps=10)
    print("-" * 50)
    print("mean loss in test set is {}, std is {}".format(
        np.mean(losses), np.std(losses)))
    print("mean inference time in test set is {}, std is {}".format(
        np.mean(inference_times), np.std(inference_times)))
    print("-" * 50)
def plot_2d_or_3d_image(
    data: Union[NdarrayTensor, List[NdarrayTensor]],
    step: int,
    writer: SummaryWriter,
    index: int = 0,
    max_channels: int = 1,
    frame_dim: int = -3,
    max_frames: int = 24,
    tag: str = "output",
) -> None:
    """Plot 2D or 3D image on the TensorBoard, 3D image will be converted to GIF image.

    Note:
        Plot 3D or 2D image(with more than 3 channels) as separate images.
        And if writer is from TensorBoardX, data has 3 channels and `max_channels=3`,
        will plot as RGB video.

    Args:
        data: target data to be plotted as image on the TensorBoard.
            The data is expected to have 'NCHW[D]' dimensions or a list of data with
            `CHW[D]` dimensions, and only plot the first in the batch.
        step: current step to plot in a chart.
        writer: specify TensorBoard or TensorBoardX SummaryWriter to plot the image.
        index: plot which element in the input data batch, default is the first element.
        max_channels: number of channels to plot.
        frame_dim: if plotting 3D image as GIF, specify the dimension used as frames,
            expect input data shape as `NCHWD`, default to `-3` (the first spatial dim)
        max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`.
        tag: tag of the plotted image on TensorBoard.
    """
    selected = data[index]
    # the selected element carries no batch dim, so a positive frame_dim
    # shifts down by one; negative indices are unaffected
    if frame_dim > 0:
        frame_dim -= 1
    img: np.ndarray = (selected.detach().cpu().numpy()
                       if isinstance(selected, torch.Tensor) else selected)

    if img.ndim == 2:
        # single grayscale image
        writer.add_image(f"{tag}_HW", rescale_array(img, 0, 1), step,
                         dataformats="HW")
    elif img.ndim == 3:
        if img.shape[0] == 3 and max_channels == 3:
            # RGB image, plotted as one CHW image
            writer.add_image(f"{tag}_CHW", img, step, dataformats="CHW")
        else:
            # one separate HW image per channel, up to max_channels
            for ch, plane in enumerate(img[:max_channels]):
                writer.add_image(f"{tag}_HW_{ch}",
                                 rescale_array(plane, 0, 1), step,
                                 dataformats="HW")
    elif img.ndim >= 4:
        # collapse leading dims onto the channel axis, keep last 3 spatial
        spatial = img.shape[-3:]
        img = img.reshape([-1] + list(spatial))
        if (img.shape[0] == 3 and max_channels == 3 and has_tensorboardx
                and isinstance(writer, SummaryWriterX)):
            # RGB volume: move the frame dim to the end as `T` and log video
            video = np.moveaxis(img, frame_dim, -1)
            writer.add_video(tag, video[None], step, fps=max_frames,
                             dataformats="NCHWT")
        else:
            # scale data to 0 - 255 and plot every channel as its own GIF
            max_channels = min(max_channels, img.shape[0])
            scaled = np.stack(
                [rescale_array(c, 0, 255) for c in img[:max_channels]],
                axis=0)
            add_animated_gif(writer, f"{tag}_HWD", scaled,
                             max_out=max_channels, frame_dim=frame_dim,
                             global_step=step)
class Logger(object):
    """Metric logger writing to MetersGroup log files and (optionally)
    TensorBoard.

    Keys must start with ``train`` or ``eval``; the prefix selects which
    MetersGroup file the value is recorded in.

    Fix vs. original: ``log_param`` crashed with AttributeError on modules
    created with ``bias=False`` (``param.bias`` exists but is ``None``, so
    the ``hasattr`` check passed and ``param.bias.data`` failed). Now the
    bias histograms are only logged when the bias tensor is present.
    """

    def __init__(self, log_dir, use_tb=True, config="rl"):
        self._log_dir = log_dir
        if use_tb:
            tb_dir = os.path.join(log_dir, "tb")
            # start every run from a clean TensorBoard event directory
            if os.path.exists(tb_dir):
                shutil.rmtree(tb_dir)
            self._sw = SummaryWriter(tb_dir)
        else:
            self._sw = None
        self._train_mg = MetersGroup(os.path.join(log_dir, "train.log"),
                                     formating=FORMAT_CONFIG[config]["train"])
        self._eval_mg = MetersGroup(os.path.join(log_dir, "eval.log"),
                                    formating=FORMAT_CONFIG[config]["eval"])

    def _try_sw_log(self, key, value, step):
        # no-op when TensorBoard logging is disabled
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def _try_sw_log_image(self, key, image, step):
        if self._sw is not None:
            assert image.dim() == 3
            grid = torchvision.utils.make_grid(image.unsqueeze(1))
            self._sw.add_image(key, grid, step)

    def _try_sw_log_video(self, key, frames, step):
        if self._sw is not None:
            frames = torch.from_numpy(np.array(frames))
            frames = frames.unsqueeze(0)  # add batch dim expected by add_video
            self._sw.add_video(key, frames, step, fps=30)

    def _try_sw_log_histogram(self, key, histogram, step):
        if self._sw is not None:
            self._sw.add_histogram(key, histogram, step)

    def log(self, key, value, step, n=1):
        """Log a scalar; `value` is a sum over `n` samples (mean = value/n)."""
        assert key.startswith("train") or key.startswith("eval")
        if type(value) == torch.Tensor:
            value = value.item()
        self._try_sw_log(key, value / n, step)
        mg = self._train_mg if key.startswith("train") else self._eval_mg
        mg.log(key, value, n)

    def log_param(self, key, param, step):
        """Log weight/bias value and gradient histograms of a module."""
        self.log_histogram(key + "_w", param.weight.data, step)
        if hasattr(param.weight, "grad") and param.weight.grad is not None:
            self.log_histogram(key + "_w_g", param.weight.grad.data, step)
        # bias is None for modules built with bias=False — skip it entirely
        if getattr(param, "bias", None) is not None:
            self.log_histogram(key + "_b", param.bias.data, step)
            if hasattr(param.bias, "grad") and param.bias.grad is not None:
                self.log_histogram(key + "_b_g", param.bias.grad.data, step)

    def log_image(self, key, image, step):
        assert key.startswith("train") or key.startswith("eval")
        self._try_sw_log_image(key, image, step)

    def log_video(self, key, frames, step):
        assert key.startswith("train") or key.startswith("eval")
        self._try_sw_log_video(key, frames, step)

    def log_histogram(self, key, histogram, step):
        assert key.startswith("train") or key.startswith("eval")
        self._try_sw_log_histogram(key, histogram, step)

    def dump(self, step):
        """Flush accumulated train/eval meters to their log files."""
        self._train_mg.dump(step, "train")
        self._eval_mg.dump(step, "eval")
# generate videos videos = [] for i in range(params['validation_samples']): pbar.set_description( f'Rendering rollout {i + 1}/{params["validation_samples"]}', refresh=True) initial = next(iter(validation_loader)) graph = initial[0] graph.to(device=device) pos = infer_trajectory(network, graph, params['validation_steps']) video = render_rollout_tensor(pos, title=f'step {step}') videos.append(video) videos = torch.cat(videos, dim=0) writer.add_video('rollout', videos, step, fps=60) pbar.set_description(refresh=True) if (step) % params['model_save_interval'] == 0: torch.save( { 'step': step, 'optimizer_state_dict': optimizer.state_dict(), 'model_state_dict': network.cpu().state_dict(), 'loss': loss }, join(args.save_dir, f'{args.run_name}-{step}.pt')) network.to(device=device) if step >= params['steps']: break
class Trainer(object):
    """GAN trainer for a paired image/video generator (MoCoGAN-style):
    alternates image-discriminator, video-discriminator and generator
    updates, with optional InfoGAN category supervision on videos, and logs
    losses/samples to TensorBoard.
    """

    def __init__(self, image_sampler, video_sampler, log_interval,
                 train_batches, log_folder, use_cuda=False, use_infogan=True,
                 use_categories=True):
        self.use_categories = use_categories

        self.gan_criterion = nn.BCEWithLogitsLoss()
        self.category_criterion = nn.CrossEntropyLoss()

        self.image_sampler = image_sampler
        self.video_sampler = video_sampler

        self.video_batch_size = self.video_sampler.batch_size
        self.image_batch_size = self.image_sampler.batch_size

        self.log_interval = log_interval     # batches between log/snapshot
        self.train_batches = train_batches   # total batches to train

        self.log_folder = log_folder

        self.use_cuda = use_cuda
        self.use_infogan = use_infogan

        # lazily-created (re)cycling iterators over the data samplers
        self.image_enumerator = None
        self.video_enumerator = None

        self.writer = SummaryWriter(self.log_folder)

    @staticmethod
    def ones_like(tensor, val=1.):
        # constant target tensor of `val`, detached from the graph
        return Variable(T.FloatTensor(tensor.size()).fill_(val),
                        requires_grad=False)

    @staticmethod
    def zeros_like(tensor, val=0.):
        # constant target tensor of `val`, detached from the graph
        return Variable(T.FloatTensor(tensor.size()).fill_(val),
                        requires_grad=False)

    def compute_gan_loss(self, discriminator, sample_true, sample_fake,
                         is_video):
        """Compute generator and discriminator losses for one batch.

        NOTE(review): appears to be legacy/dead code — the training loop
        uses train_discriminator/train_generator instead, and the helpers
        get_gt_for_discriminator / get_gt_for_generator are not defined in
        this file chunk; verify before calling.
        """
        real_batch = sample_true()

        batch_size = real_batch['images'].size(0)
        fake_batch, generated_categories = sample_fake(batch_size)

        real_labels, real_categorical = discriminator(
            Variable(real_batch['images']))
        fake_labels, fake_categorical = discriminator(fake_batch)

        fake_gt, real_gt = self.get_gt_for_discriminator(batch_size, real=0.)

        l_discriminator = self.gan_criterion(real_labels, real_gt) + \
            self.gan_criterion(fake_labels, fake_gt)

        # update image discriminator here

        # sample again for videos

        # update video discriminator

        # sample again
        # - videos
        # - images
        # l_vidoes + l_images -> l
        # l.backward()
        # opt.step()

        # sample again and compute for generator

        fake_gt = self.get_gt_for_generator(batch_size)
        # to real_gt
        l_generator = self.gan_criterion(fake_labels, fake_gt)

        if is_video:
            # Ask the video discriminator to learn categories from training videos
            categories_gt = Variable(
                torch.squeeze(real_batch['categories'].long()))
            l_discriminator += self.category_criterion(real_categorical,
                                                       categories_gt)

            # NOTE(review): nesting of this InfoGAN branch under `is_video`
            # reconstructed from a collapsed source — confirm against the
            # original repository.
            if self.use_infogan:
                # Ask the generator to generate categories recognizable by the discriminator
                l_generator += self.category_criterion(fake_categorical,
                                                       generated_categories)

        return l_generator, l_discriminator

    def sample_real_image_batch(self):
        """Return the next real image batch, cycling the sampler forever."""
        if self.image_enumerator is None:
            self.image_enumerator = enumerate(self.image_sampler)

        batch_idx, batch = next(self.image_enumerator)
        b = batch
        if self.use_cuda:
            for k, v in batch.items():
                b[k] = v.cuda()

        # restart the iterator once the last batch was consumed
        if batch_idx == len(self.image_sampler) - 1:
            self.image_enumerator = enumerate(self.image_sampler)

        return b

    def sample_real_video_batch(self):
        """Return the next real video batch, cycling the sampler forever."""
        if self.video_enumerator is None:
            self.video_enumerator = enumerate(self.video_sampler)

        batch_idx, batch = next(self.video_enumerator)
        b = batch
        if self.use_cuda:
            for k, v in batch.items():
                b[k] = v.cuda()

        if batch_idx == len(self.video_sampler) - 1:
            self.video_enumerator = enumerate(self.video_sampler)

        return b

    def train_discriminator(self, discriminator, sample_true, sample_fake,
                            opt, batch_size, use_categories):
        """One discriminator update: real->1, fake->0 (fake detached)."""
        opt.zero_grad()

        real_batch = sample_true()
        batch = Variable(real_batch['images'], requires_grad=False)

        # util.show_batch(batch.data)

        fake_batch, generated_categories = sample_fake(batch_size)

        real_labels, real_categorical = discriminator(batch)
        # detach so generator gradients are not computed here
        fake_labels, fake_categorical = discriminator(fake_batch.detach())

        ones = self.ones_like(real_labels)
        zeros = self.zeros_like(fake_labels)

        l_discriminator = self.gan_criterion(real_labels, ones) + \
            self.gan_criterion(fake_labels, zeros)

        if use_categories:
            # Ask the video discriminator to learn categories from training videos
            categories_gt = Variable(torch.squeeze(
                real_batch['categories'].long()),
                requires_grad=False)
            l_discriminator += self.category_criterion(
                real_categorical.squeeze(), categories_gt)

        l_discriminator.backward()
        opt.step()

        return l_discriminator

    def train_generator(self, image_discriminator, video_discriminator,
                        sample_fake_images, sample_fake_videos, opt):
        """One generator update on both image and video discriminators."""
        opt.zero_grad()

        # train on images
        fake_batch, generated_categories = sample_fake_images(
            self.image_batch_size)
        fake_labels, fake_categorical = image_discriminator(fake_batch)
        all_ones = self.ones_like(fake_labels)
        l_generator = self.gan_criterion(fake_labels, all_ones)

        # train on videos
        fake_batch, generated_categories = sample_fake_videos(
            self.video_batch_size)
        fake_labels, fake_categorical = video_discriminator(fake_batch)
        all_ones = self.ones_like(fake_labels)
        l_generator += self.gan_criterion(fake_labels, all_ones)

        if self.use_infogan:
            # Ask the generator to generate categories recognizable by the discriminator
            l_generator += self.category_criterion(fake_categorical.squeeze(),
                                                   generated_categories)

        l_generator.backward()
        opt.step()

        return l_generator

    def train(self, generator, image_discriminator, video_discriminator):
        """Main training loop: runs until train_batches updates are done."""
        if self.use_cuda:
            generator.cuda()
            image_discriminator.cuda()
            video_discriminator.cuda()

        # create optimizers
        opt_generator = optim.Adam(generator.parameters(), lr=0.0002,
                                   betas=(0.5, 0.999), weight_decay=0.00001)
        opt_image_discriminator = optim.Adam(
            image_discriminator.parameters(), lr=0.0002,
            betas=(0.5, 0.999), weight_decay=0.00001)
        opt_video_discriminator = optim.Adam(
            video_discriminator.parameters(), lr=0.0002,
            betas=(0.5, 0.999), weight_decay=0.00001)

        # training loop
        def sample_fake_image_batch(batch_size):
            return generator.sample_images(batch_size)

        def sample_fake_video_batch(batch_size):
            return generator.sample_videos(batch_size)

        def init_logs():
            return {'l_gen': 0, 'l_image_dis': 0, 'l_video_dis': 0}

        batch_num = 0

        logs = init_logs()

        start_time = time.time()

        while True:
            generator.train()
            image_discriminator.train()
            video_discriminator.train()

            opt_generator.zero_grad()
            opt_video_discriminator.zero_grad()

            # train image discriminator
            l_image_dis = self.train_discriminator(
                image_discriminator, self.sample_real_image_batch,
                sample_fake_image_batch, opt_image_discriminator,
                self.image_batch_size, use_categories=False)

            # train video discriminator
            l_video_dis = self.train_discriminator(
                video_discriminator, self.sample_real_video_batch,
                sample_fake_video_batch, opt_video_discriminator,
                self.video_batch_size, use_categories=self.use_categories)

            # train generator
            l_gen = self.train_generator(image_discriminator,
                                         video_discriminator,
                                         sample_fake_image_batch,
                                         sample_fake_video_batch,
                                         opt_generator)

            logs['l_gen'] += l_gen.data
            logs['l_image_dis'] += l_image_dis.data
            logs['l_video_dis'] += l_video_dis.data

            batch_num += 1

            if batch_num % self.log_interval == 0:
                log_string = "Batch %d" % batch_num
                for k, v in logs.items():
                    log_string += " [%s] %5.3f" % (k, v / self.log_interval)
                log_string += ". Took %5.2f" % (time.time() - start_time)

                print(log_string)

                # log loss
                for tag, value in list(logs.items()):
                    self.writer.add_scalar(tag, value / self.log_interval,
                                           batch_num)

                logs = init_logs()
                start_time = time.time()

                generator.eval()

                images, _ = sample_fake_image_batch(self.image_batch_size)
                # log images: channels 0-2 = objects, 3-5 = background
                self.writer.add_images('images-objs', images[:, 0:3, :, :],
                                       batch_num)
                self.writer.add_images('images-background',
                                       images[:, 3:6, :, :], batch_num)

                videos, _ = sample_fake_video_batch(self.video_batch_size)
                # log videos: permute to the (N, T, C, H, W) layout
                # expected by add_video
                vid_obj = videos[:, 0:3, :, :, :].permute(0, 2, 1, 3, 4)
                vid_background = videos[:, 3:6, :, :, :].permute(
                    0, 2, 1, 3, 4)
                self.writer.add_video("video_obj", vid_obj, batch_num,
                                      fps=35)
                self.writer.add_video("video_background", vid_background,
                                      batch_num, fps=35)

                torch.save(
                    generator,
                    os.path.join(self.log_folder,
                                 'generator_%05d.pytorch' % batch_num))

            if batch_num >= self.train_batches:
                torch.save(
                    generator,
                    os.path.join(self.log_folder,
                                 'generator_%05d.pytorch' % batch_num))
                break
class Logger:
    """Thin convenience wrapper around a tensorboard ``SummaryWriter``.

    Provides helpers for scalars, grouped scalars, images, videos,
    matplotlib figures, rendered graphs, and rollout-path videos.
    """

    def __init__(self, log_dir, n_logged_samples=10, summary_writer=None):
        self._log_dir = log_dir
        print('########################')
        print('logging outputs to ', log_dir)
        print('########################')
        self._n_logged_samples = n_logged_samples
        # aggressive flushing (1s, queue of 1) so events appear promptly
        self._summ_writer = SummaryWriter(log_dir, flush_secs=1, max_queue=1)

    def log_scalar(self, scalar, name, step_):
        """Record one scalar value under tag *name* at step *step_*."""
        self._summ_writer.add_scalar('{}'.format(name), scalar, step_)

    def log_scalars(self, scalar_dict, group_name, step, phase):
        """Will log all scalars in the same plot."""
        tag = '{}_{}'.format(group_name, phase)
        self._summ_writer.add_scalars(tag, scalar_dict, step)

    def log_image(self, image, name, step):
        """Record a single [C, H, W] image."""
        assert (len(image.shape) == 3)  # [C, H, W]
        self._summ_writer.add_image('{}'.format(name), image, step)

    def log_video(self, video_frames, name, step, fps=10):
        """Record a batch of videos shaped [N, T, C, H, W]."""
        assert len(video_frames.shape) == 5, "Need [N, T, C, H, W] input tensor for video logging!"
        self._summ_writer.add_video('{}'.format(name), video_frames, step, fps=fps)

    def log_paths_as_videos(self,
                            paths,
                            step,
                            max_videos_to_save=2,
                            fps=10,
                            video_title='video'):
        """Turn rollout dicts (key 'image_obs', frames [T, H, W, C]) into a
        length-padded video batch and log it via :meth:`log_video`."""
        # [T, H, W, C] -> [T, C, H, W] for every rollout
        clips = [np.transpose(p['image_obs'], [0, 3, 1, 2]) for p in paths]
        max_videos_to_save = np.min([max_videos_to_save, len(clips)])

        # longest rollout among the clips that will actually be saved
        max_length = clips[0].shape[0]
        for clip in clips[:max_videos_to_save]:
            if clip.shape[0] > max_length:
                max_length = clip.shape[0]

        # repeat each short clip's final frame until all lengths match
        for i in range(max_videos_to_save):
            deficit = max_length - clips[i].shape[0]
            if deficit > 0:
                tail = np.tile([clips[i][-1]], (deficit, 1, 1, 1))
                clips[i] = np.concatenate([clips[i], tail], 0)

        print("Logging videos")
        batch = np.stack(clips[:max_videos_to_save], 0)
        self.log_video(batch, video_title, step, fps=fps)

    def log_figures(self, figure, name, step, phase):
        """figure: matplotlib.pyplot figure handle"""
        assert figure.shape[0] > 0, "Figure logging requires input shape [batch x figures]!"
        self._summ_writer.add_figure('{}_{}'.format(name, phase), figure, step)

    def log_figure(self, figure, name, step, phase):
        """figure: matplotlib.pyplot figure handle"""
        self._summ_writer.add_figure('{}_{}'.format(name, phase), figure, step)

    def log_graph(self, array, name, step, phase):
        """Render *array* through ``plot_graph`` and log the resulting image."""
        rendered = plot_graph(array)
        self._summ_writer.add_image('{}_{}'.format(name, phase), rendered, step)

    def dump_scalars(self, log_path=None):
        """Export all scalar data to JSON (default: <log_dir>/scalar_data.json)."""
        if log_path is None:
            log_path = os.path.join(self._log_dir, "scalar_data.json")
        self._summ_writer.export_scalars_to_json(log_path)

    def flush(self):
        self._summ_writer.flush()
class Trainer:
    """Generic single-model training harness.

    Parses a YAML experiment config from the command line, builds train/val
    DataLoaders, and runs an epoch/step loop with periodic validation,
    tensorboard logging, and checkpointing.

    NOTE(review): ``__init__`` calls ``argparse.parse_args()`` directly, so
    constructing a Trainer consumes ``sys.argv`` — this class can only be
    instantiated from a script entry point.
    """

    def __init__(self,
                 save_name='train',
                 description="Default model trainer",
                 drop_last=False,
                 allow_val_grad=False):
        now = datetime.datetime.now()
        # CLI: positional config file plus optional save path / device list
        parser = argparse.ArgumentParser(description=description)
        parser.add_argument('experiment_file',
                           type=str,
                           help='path to YAML experiment config file')
        parser.add_argument(
            '--save_path',
            type=str,
            default='',
            help=
            'path to place model save file in during training (overwrites config)'
        )
        parser.add_argument('--device',
                           type=int,
                           default=None,
                           nargs='+',
                           help='target device (uses all if not specified)')
        args = parser.parse_args()
        self._config = parse_basic_config(args.experiment_file)
        # keep a pristine copy of the config for dumping to disk below,
        # before any CLI overrides are applied
        save_config = copy.deepcopy(self._config)
        if args.save_path:
            self._config['save_path'] = args.save_path

        # initialize device: first listed device is the primary one
        def_device = 0 if args.device is None else args.device[0]
        self._device = torch.device("cuda:{}".format(def_device))
        self._device_list = args.device
        self._allow_val_grad = allow_val_grad

        # parse dataset class and create train/val loaders
        # (the 'type' key is popped so the remaining dict is pure kwargs)
        dataset_class = get_dataset(self._config['dataset'].pop('type'))
        dataset = dataset_class(**self._config['dataset'], mode='train')
        val_dataset = dataset_class(**self._config['dataset'], mode='val')
        self._train_loader = DataLoader(
            dataset,
            batch_size=self._config['batch_size'],
            shuffle=True,
            num_workers=self._config.get('loader_workers', cpu_count()),
            drop_last=drop_last,
            # reseed numpy per worker so workers don't share a RNG stream
            worker_init_fn=lambda w: np.random.seed(
                np.random.randint(2**29) + w))
        self._val_loader = DataLoader(
            val_dataset,
            batch_size=self._config['batch_size'],
            shuffle=True,
            # NOTE(review): min(1, ...) always yields <= 1 worker;
            # possibly max(1, ...) was intended — confirm with author
            num_workers=min(1,
                            self._config.get('loader_workers', cpu_count())),
            drop_last=True,
            worker_init_fn=lambda w: np.random.seed(
                np.random.randint(2**29) + w))

        # set of file saving: timestamped run directory under save_path
        save_dir = os.path.join(
            self._config.get('save_path', './'),
            '{}_ckpt-{}-{}_{}-{}-{}'.format(save_name, now.hour, now.minute,
                                            now.day, now.month, now.year))
        save_dir = os.path.expanduser(save_dir)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        # dump the original (pre-override) config alongside the checkpoints
        with open(os.path.join(save_dir, 'config.yaml'), 'w') as f:
            yaml.dump(save_config, f, default_flow_style=False)
        self._writer = SummaryWriter(log_dir=os.path.join(save_dir, 'log'))
        self._save_fname = os.path.join(save_dir, 'model_save')
        # _step stays None until train() starts; the step property guards this
        self._step = None

    @property
    def config(self):
        """Deep copy of the parsed experiment config (callers cannot mutate it)."""
        return copy.deepcopy(self._config)

    def train(self,
              model,
              train_fn,
              weights_fn=None,
              val_fn=None,
              save_fn=None,
              optim_weights=None):
        """Run the full training loop.

        train_fn / val_fn: callables (model, device, *batch) -> (loss, stats_dict).
        weights_fn: optional callable returning the module to checkpoint.
        save_fn: optional callable (save_fname, step) that replaces the
        default torch.save checkpointing.
        optim_weights: optional parameter iterable for the optimizer
        (defaults to model.parameters()).
        """
        # wrap model in DataParallel if needed and transfer to correct device
        if self.device_count > 1:
            model = nn.DataParallel(model, device_ids=self.device_list)
        model = model.to(self._device)

        # initializer optimizer and lr scheduler
        optim_weights = optim_weights if optim_weights is not None else model.parameters(
        )
        optimizer, scheduler = self._build_optimizer_and_scheduler(
            optim_weights)

        # initialize constants:
        epochs = self._config.get('epochs', 1)
        # vlm_alpha: EMA coefficient for the validation-loss running mean
        vlm_alpha = self._config.get('vlm_alpha', 0.6)
        log_freq = self._config.get('log_freq', 20)
        self._img_log_freq = img_log_freq = self._config.get(
            'img_log_freq', 500)
        # image logging piggybacks on scalar-logging steps, so the
        # frequencies must nest
        assert img_log_freq % log_freq == 0, "log_freq must divide img_log_freq!"
        save_freq = self._config.get('save_freq', 5000)

        if val_fn is None:
            val_fn = train_fn

        self._step = 0
        train_stats = {'loss': 0}
        val_iter = iter(self._val_loader)
        vl_running_mean = None
        for e in range(epochs):
            for inputs in self._train_loader:
                self._zero_grad(optimizer)
                loss_i, stats_i = train_fn(model, self._device, *inputs)
                self._step_optim(loss_i, self._step, optimizer)

                # calculate iter stats: incremental mean over the current
                # log window (mod_step counts steps since the last log)
                mod_step = self._step % log_freq
                train_stats['loss'] = (self._loss_to_scalar(loss_i) +
                                       mod_step * train_stats['loss']) / (
                                           mod_step + 1)
                for k, v in stats_i.items():
                    if isinstance(v, torch.Tensor):
                        assert len(
                            v.shape) >= 4, "assumes 4dim BCHW image tensor!"
                        # NOTE(review): assigning v here makes the averaging
                        # update below a no-op for tensors, so tensor stats
                        # always hold the *latest* value rather than a
                        # running mean — looks deliberate, but confirm
                        train_stats[k] = v
                    if k not in train_stats:
                        train_stats[k] = 0
                    train_stats[k] = (v + mod_step * train_stats[k]) / (
                        mod_step + 1)

                # validation + logging happen at the start of each log window
                if mod_step == 0:
                    try:
                        val_inputs = next(val_iter)
                    except StopIteration:
                        # val loader exhausted: restart it
                        val_iter = iter(self._val_loader)
                        val_inputs = next(val_iter)

                    if self._allow_val_grad:
                        # gradients allowed through validation (e.g. for
                        # meta-learning style objectives)
                        model = model.eval()
                        val_loss, val_stats = val_fn(model, self._device,
                                                     *val_inputs)
                        model = model.train()
                        val_loss = self._loss_to_scalar(val_loss)
                    else:
                        with torch.no_grad():
                            model = model.eval()
                            val_loss, val_stats = val_fn(
                                model, self._device, *val_inputs)
                            model = model.train()
                            val_loss = self._loss_to_scalar(val_loss)

                    # update running mean stat
                    # NOTE(review): on the very first update the seed value
                    # is blended with itself (no early return), so the first
                    # EMA point equals val_loss exactly — harmless but subtle
                    if vl_running_mean is None:
                        vl_running_mean = val_loss
                    vl_running_mean = val_loss * vlm_alpha + vl_running_mean * (
                        1 - vlm_alpha)

                    self._writer.add_scalar('loss/val', val_loss, self._step)
                    for stats_dict, mode in zip([train_stats, val_stats],
                                                ['train', 'val']):
                        for k, v in stats_dict.items():
                            # tensors are image/video stats: only log them on
                            # the (coarser) img_log_freq cadence
                            if isinstance(v, torch.Tensor
                                          ) and self.step % img_log_freq == 0:
                                if len(v.shape) == 5:
                                    # [N, T, C, H, W] -> video
                                    self._writer.add_video(
                                        '{}/{}'.format(k, mode), v.cpu(),
                                        self._step)
                                else:
                                    # [N, C, H, W] -> tiled image grid
                                    # (torchvision assumed imported at file top)
                                    v_grid = torchvision.utils.make_grid(
                                        v.cpu(), padding=5)
                                    self._writer.add_image(
                                        '{}/{}'.format(k, mode), v_grid,
                                        self._step)
                            elif not isinstance(v, torch.Tensor):
                                self._writer.add_scalar(
                                    '{}/{}'.format(k, mode), v, self._step)

                    # add learning rate parameter to log
                    lrs = np.mean([p['lr'] for p in optimizer.param_groups])
                    self._writer.add_scalar('lr', lrs, self._step)

                    # flush to disk and print
                    self._writer.file_writer.flush()
                    print(
                        'epoch {3}/{4}, step {0}: loss={1:.4f} \t val loss={2:.4f}'
                        .format(self._step, train_stats['loss'],
                                vl_running_mean, e, epochs))
                else:
                    print('step {0}: loss={1:.4f}'.format(
                        self._step, train_stats['loss']),
                          end='\r')

                self._step += 1
                # periodic checkpointing
                if self._step % save_freq == 0:
                    if save_fn is not None:
                        save_fn(self._save_fname, self._step)
                    else:
                        save_module = model
                        if weights_fn is not None:
                            save_module = weights_fn()
                        elif isinstance(model, nn.DataParallel):
                            # unwrap so the checkpoint loads without DataParallel
                            save_module = model.module
                        torch.save(
                            save_module,
                            self._save_fname + '-{}.pt'.format(self._step))
                        if self._config.get('save_optim', False):
                            torch.save(
                                optimizer.state_dict(), self._save_fname +
                                '-optim-{}.pt'.format(self._step))
            # per-epoch LR schedule step, driven by the validation EMA
            scheduler.step(val_loss=vl_running_mean)

    @property
    def device_count(self):
        # fall back to all visible GPUs when no explicit list was given
        if self._device_list is None:
            return torch.cuda.device_count()
        return len(self._device_list)

    @property
    def device_list(self):
        # fall back to all visible GPUs when no explicit list was given
        if self._device_list is None:
            return [i for i in range(torch.cuda.device_count())]
        return copy.deepcopy(self._device_list)

    @property
    def device(self):
        return copy.deepcopy(self._device)

    def _build_optimizer_and_scheduler(self, optim_weights):
        # Adam with config-driven lr / weight decay; scheduler built from
        # the (possibly empty) 'lr_schedule' config section
        optimizer = torch.optim.Adam(optim_weights,
                                     self._config['lr'],
                                     weight_decay=self._config.get(
                                         'weight_decay', 0))
        return optimizer, build_scheduler(optimizer,
                                          self._config.get('lr_schedule', {}))

    def _step_optim(self, loss, step, optimizer):
        # hook point for subclasses; 'step' is unused in the base class
        loss.backward()
        optimizer.step()

    def _zero_grad(self, optimizer):
        optimizer.zero_grad()

    def _loss_to_scalar(self, loss):
        # detach a 0-dim loss tensor to a plain Python float
        return loss.item()

    @property
    def step(self):
        if self._step is None:
            raise Exception("Optimization has not begun!")
        return self._step

    @property
    def is_img_log_step(self):
        """True when the current step falls on the image-logging cadence."""
        return self._step % self._img_log_freq == 0