class BaseTrainer:
    def __init__(self, options):
        self.options = options
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # Override this function to define your model, optimizers etc.
        self._init_fn()
        self.saver = CheckpointSaver(save_dir=options.checkpoint_dir)
        self.summary_writer = SummaryWriter(self.options.summary_dir)

        self.checkpoint = None
        if self.options.resume and self.saver.exists_checkpoint():
            self.checkpoint = self.saver.load_checkpoint(
                self.models_dict, self.optimizers_dict,
                checkpoint_file=self.options.checkpoint)

        if self.checkpoint is None:
            self.epoch_count = 0
            self.step_count = 0
        else:
            self.epoch_count = self.checkpoint['epoch']
            self.step_count = self.checkpoint['total_step_count']

        # self.lr_schedulers = {k: torch.optim.lr_scheduler.ReduceLROnPlateau(v, patience=5)
        #                       for k, v in self.optimizers_dict.items()}
        self.lr_schedulers = {k: torch.optim.lr_scheduler.ExponentialLR(
                                  v, gamma=self.options.lr_decay,
                                  last_epoch=self.epoch_count - 1)
                              for k, v in self.optimizers_dict.items()}
        for opt in self.optimizers_dict:
            self.lr_schedulers[opt].step()

    def _init_fn(self):
        raise NotImplementedError('You need to provide an _init_fn method')

    # @profile
    def train(self):
        self.endtime = time.time() + self.options.time_to_run
        for epoch in tqdm(range(self.epoch_count, self.options.num_epochs),
                          total=self.options.num_epochs, initial=self.epoch_count):
            train_data_loader = CheckpointDataLoader(
                self.train_ds, checkpoint=self.checkpoint,
                batch_size=self.options.batch_size,
                num_workers=self.options.num_workers,
                pin_memory=self.options.pin_memory,
                shuffle=self.options.shuffle_train)
            for step, batch in enumerate(
                    tqdm(train_data_loader, desc='Epoch ' + str(epoch),
                         total=math.ceil(len(self.train_ds) / self.options.batch_size),
                         initial=train_data_loader.checkpoint_batch_idx),
                    train_data_loader.checkpoint_batch_idx):
                if time.time() < self.endtime:
                    batch = {k: v.to(self.device) for k, v in batch.items()}
                    out = self._train_step(batch)
                    self.step_count += 1
                    if self.step_count % self.options.summary_steps == 0:
                        self._train_summaries(batch, *out)
                    if self.step_count % self.options.checkpoint_steps == 0:
                        self.saver.save_checkpoint(
                            self.models_dict, self.optimizers_dict, epoch, step + 1,
                            self.options.batch_size,
                            train_data_loader.sampler.dataset_perm, self.step_count)
                        tqdm.write('Checkpoint saved')
                    if self.step_count % self.options.test_steps == 0:
                        val_loss = self.test()
                        # for opt in self.optimizers_dict:
                        #     self.lr_schedulers[opt].step(val_loss)
                else:
                    tqdm.write('Timeout reached')
                    self.saver.save_checkpoint(
                        self.models_dict, self.optimizers_dict, epoch, step,
                        self.options.batch_size,
                        train_data_loader.sampler.dataset_perm, self.step_count)
                    tqdm.write('Checkpoint saved')
                    sys.exit(0)

            # Apply the learning rate scheduling policy.
            for opt in self.optimizers_dict:
                self.lr_schedulers[opt].step()
            # Load a checkpoint only on startup; for the next epochs
            # just iterate over the dataset as usual.
            self.checkpoint = None
            # Save a checkpoint every 10 epochs.
            if (epoch + 1) % 10 == 0:
                self.saver.save_checkpoint(self.models_dict, self.optimizers_dict,
                                           epoch + 1, 0, self.options.batch_size,
                                           None, self.step_count)
        return

    def _get_lr(self):
        return next(iter(self.optimizers_dict.values())).param_groups[0]['lr']
        # return next(iter(self.lr_schedulers.values())).get_lr()[0]

    def _train_step(self, input_batch):
        raise NotImplementedError('You need to provide a _train_step method')

    def _train_summaries(self, input_batch):
        raise NotImplementedError('You need to provide a _train_summaries method')

    def test(self, input_batch):
        raise NotImplementedError('You need to provide a test method')
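# A minimal standalone sketch (illustrative names, not part of the trainer
# above) of the resume idiom behind `last_epoch=self.epoch_count - 1`. Note
# that PyTorch requires an 'initial_lr' entry in every param group when a
# scheduler is constructed with last_epoch != -1, and the exact restored lr
# differs across PyTorch versions, so saving and loading
# scheduler.state_dict() is the more robust way to resume a schedule.
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
for group in optimizer.param_groups:
    group.setdefault('initial_lr', group['lr'])  # required when last_epoch != -1
resumed_epoch = 5
scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer, gamma=0.9, last_epoch=resumed_epoch - 1)
print(optimizer.param_groups[0]['lr'])  # decayed lr; exact value is version-dependent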
class BaseTrainer(object): """Base class for Trainer objects. Takes care of checkpointing/logging/resuming training. """ def __init__(self, options): self.options = options self.endtime = time.time() + self.options.time_to_run self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # override this function to define your model, optimizers etc. self.init_fn() self.saver = CheckpointSaver(save_dir=options.checkpoint_dir) self.checkpoint = None if self.options.resume and self.saver.exists_checkpoint(): self.checkpoint = self.saver.load_checkpoint(self.models_dict, self.optimizers_dict, checkpoint_file=self.options.checkpoint) if self.checkpoint is None: self.epoch_count = 0 self.step_count = 0 else: self.epoch_count = self.checkpoint['epoch'] self.step_count = self.checkpoint['total_step_count'] def load_pretrained(self, checkpoint_file=None): """Load a pretrained checkpoint. This is different from resuming training using --resume. """ if checkpoint_file is not None: checkpoint = torch.load(checkpoint_file) for model in self.models_dict: if model in checkpoint: state_dict = checkpoint[model] renamed_state_dict = OrderedDict() # change the names in the state_dict to match the new layer for key, value in state_dict.items(): if 'layer' in key: names = key.split('.') names[1:1] = ['hmr_layer'] new_key = '.'.join(n for n in names) renamed_state_dict[new_key] = value else: renamed_state_dict[key] = value self.models_dict[model].load_state_dict(renamed_state_dict, strict=False) @staticmethod def linear_rampup(current, rampup_length): """Linear rampup""" assert current >= 0 and rampup_length >= 0 if current >= rampup_length: return 1.0 else: return current / rampup_length def train(self): """Training process.""" ramp_step = 0 # Run training for num_epochs epochs for epoch in tqdm(range(self.epoch_count, self.options.num_epochs), total=self.options.num_epochs, initial=self.epoch_count): # ------------------ update image size intervals ---------------------- self.train_ds.update_size_intervals(epoch) # --------------------------------------------------------------------- # ------------------ update batch size ---------------------- if epoch == 0: batch_size = self.options.batch_size # 24 elif epoch == 1: batch_size = self.options.batch_size // 2 # 12 else: batch_size = self.options.batch_size // 3 # 8 if epoch == 3: self.options.checkpoint_steps = 2000 # --------------------------------------------------------------------- # Create new DataLoader every epoch and (possibly) resume from an arbitrary step inside an epoch train_data_loader = CheckpointDataLoader(self.train_ds, checkpoint=self.checkpoint, batch_size=batch_size, num_workers=self.options.num_workers, pin_memory=self.options.pin_memory, shuffle=self.options.shuffle_train) # init alphas if epoch <= 3: self.model.init_alphas(epoch+1, self.device) # Iterate over all batches in an epoch for step, batch in enumerate(tqdm(train_data_loader, desc='Epoch '+str(epoch), total=len(self.train_ds) // batch_size, initial=train_data_loader.checkpoint_batch_idx), train_data_loader.checkpoint_batch_idx): # ------------------ ramp consistency loss weight after updating the scale interval ---------------------- if self.options.ramp == 'up': total_ramp = (len(self.train_ds) // self.options.batch_size) * 5 self.consistency_loss_ramp = self.linear_rampup(ramp_step, total_ramp) ramp_step += 1 elif self.options.ramp == 'down': total_ramp = (len(self.train_ds) // self.options.batch_size) * 5 consistency_loss_ramp = self.linear_rampup(ramp_step, 
total_ramp) self.consistency_loss_ramp = 1.0 - consistency_loss_ramp ramp_step += 1 else: self.consistency_loss_ramp = 1.0 # --------------------------------------------------------------------- if time.time() < self.endtime: batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) and k != 'sample_index' else v for k,v in batch.items()} out = self.train_step(batch) self.step_count += 1 # Save checkpoint every checkpoint_steps steps if self.step_count % self.options.checkpoint_steps == 0 and epoch >= 3: self.saver.save_checkpoint(self.models_dict, self.optimizers_dict, epoch, step+1, self.options.batch_size, train_data_loader.sampler.dataset_perm, self.step_count) tqdm.write('Checkpoint saved') else: tqdm.write('Timeout reached') self.saver.save_checkpoint(self.models_dict, self.optimizers_dict, epoch, step, self.options.batch_size, train_data_loader.sampler.dataset_perm, self.step_count) tqdm.write('Checkpoint saved') sys.exit(0) # for the first 3 epochs, we only train half epoch if epoch == 0: if (step + 1) == (len(self.train_ds) // (self.options.batch_size * 2)): break elif epoch == 1: if (step + 1) == (len(self.train_ds) // self.options.batch_size): break elif epoch == 2: if (step + 1) == (len(self.train_ds) // (self.options.batch_size * 2)) * 3: break # load a checkpoint only on startup, for the next epochs # just iterate over the dataset as usual self.checkpoint=None # update learning rate if lr scheduler is epoch-based if self.lr_scheduler is not None and isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ExponentialLR): if (epoch + 1) % 4 == 0: self.lr_scheduler.step() return # The following methods (with the possible exception of test) have to be implemented in the derived classes def init_fn(self): raise NotImplementedError('You need to provide an _init_fn method') def train_step(self, input_batch): raise NotImplementedError('You need to provide a train_step method') def train_summaries(self, input_batch): raise NotImplementedError('You need to provide a _train_summaries method') def test(self): pass
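# A small standalone sketch of the ramp schedule above: the consistency weight
# climbs linearly from 0 to 1 over five epochs' worth of steps ('up'), or falls
# from 1 to 0 ('down'). The step counts here are illustrative.
def linear_rampup(current, rampup_length):
    if current >= rampup_length:
        return 1.0
    return current / rampup_length

steps_per_epoch = 100
total_ramp = steps_per_epoch * 5
for ramp_step in (0, 250, 500, 750):
    up = linear_rampup(ramp_step, total_ramp)
    down = 1.0 - up
    print(ramp_step, round(up, 2), round(down, 2))
# 0 0.0 1.0 / 250 0.5 0.5 / 500 1.0 0.0 / 750 1.0 0.0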
class BaseTrainer(object): """Base class for Trainer objects. Takes care of checkpointing/logging/resuming training. """ def __init__(self, options): self.options = options self.endtime = time.time() + self.options.time_to_run self.device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu') # override this function to define your model, optimizers etc. self.init_fn() self.saver = CheckpointSaver(save_dir=options.checkpoint_dir) self.summary_writer = SummaryWriter(self.options.summary_dir) self.checkpoint = None if self.options.resume and self.saver.exists_checkpoint(): self.checkpoint = self.saver.load_checkpoint( self.models_dict, self.optimizers_dict, checkpoint_file=self.options.checkpoint) if self.checkpoint is None: self.epoch_count = 0 self.step_count = 0 else: self.epoch_count = self.checkpoint['epoch'] self.step_count = self.checkpoint['total_step_count'] def load_pretrained(self, checkpoint_file=None): """Load a pretrained checkpoint. This is different from resuming training using --resume. """ if checkpoint_file is not None: checkpoint = torch.load(checkpoint_file) for model in self.models_dict: if model in checkpoint: self.models_dict[model].load_state_dict(checkpoint[model]) print('Checkpoint loaded') def train(self): """Training process.""" # Run training for num_epochs epochs for epoch in tqdm(range(self.epoch_count, self.options.num_epochs), total=self.options.num_epochs, initial=self.epoch_count): # Create new DataLoader every epoch and (possibly) resume from an arbitrary step inside an epoch train_data_loader = CheckpointDataLoader( self.train_ds, checkpoint=self.checkpoint, batch_size=self.options.batch_size, num_workers=self.options.num_workers, pin_memory=self.options.pin_memory, shuffle=self.options.shuffle_train) # Iterate over all batches in an epoch for step, batch in enumerate( tqdm(train_data_loader, desc='Epoch ' + str(epoch), total=len(self.train_ds) // self.options.batch_size, initial=train_data_loader.checkpoint_batch_idx), train_data_loader.checkpoint_batch_idx): if time.time() < self.endtime: batch = { k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items() } out = self.train_step(batch) self.step_count += 1 # Tensorboard logging every summary_steps steps if self.step_count % self.options.summary_steps == 0: self.train_summaries(batch, *out) # Save checkpoint every checkpoint_steps steps if self.step_count % self.options.checkpoint_steps == 0: self.saver.save_checkpoint( self.models_dict, self.optimizers_dict, epoch, step + 1, self.options.batch_size, train_data_loader.sampler.dataset_perm, self.step_count) tqdm.write('Checkpoint saved') # Run validation every test_steps steps if self.step_count % self.options.test_steps == 0: self.test() else: tqdm.write('Timeout reached') self.saver.save_checkpoint( self.models_dict, self.optimizers_dict, epoch, step, self.options.batch_size, train_data_loader.sampler.dataset_perm, self.step_count) tqdm.write('Checkpoint saved') sys.exit(0) # load a checkpoint only on startup, for the next epochs # just iterate over the dataset as usual self.checkpoint = None # save checkpoint after each epoch if (epoch + 1) % 10 == 0: # self.saver.save_checkpoint(self.models_dict, self.optimizers_dict, epoch+1, 0, self.step_count) self.saver.save_checkpoint(self.models_dict, self.optimizers_dict, epoch + 1, 0, self.options.batch_size, None, self.step_count) return # The following methods (with the possible exception of test) have to be implemented in the derived classes def init_fn(self): raise 
NotImplementedError('You need to provide an _init_fn method') def train_step(self, input_batch): raise NotImplementedError('You need to provide a _train_step method') def train_summaries(self, input_batch): raise NotImplementedError( 'You need to provide a _train_summaries method') def test(self): pass
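# A minimal sketch of a derived trainer for the base class above. The toy
# dataset, model, and loss are illustrative stand-ins, not part of the
# original code; the base class only requires that init_fn populates
# models_dict/optimizers_dict and train_step returns what train_summaries
# expects.
class ToyTrainer(BaseTrainer):
    def init_fn(self):
        class ToyDataset(torch.utils.data.Dataset):
            def __len__(self):
                return 128
            def __getitem__(self, idx):
                return {'x': torch.randn(10), 'y': torch.randn(1)}

        self.train_ds = ToyDataset()
        self.model = torch.nn.Linear(10, 1).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        self.criterion = torch.nn.MSELoss()
        # The base class checkpoints whatever is registered here.
        self.models_dict = {'model': self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}

    def train_step(self, batch):
        self.model.train()
        pred = self.model(batch['x'])
        loss = self.criterion(pred, batch['y'])
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return pred, loss

    def train_summaries(self, batch, pred, loss):
        self.summary_writer.add_scalar('train/loss', loss.item(), self.step_count)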
def main():
    args = parser.parse_args()

    if args.output:
        output_base = args.output
    else:
        output_base = './output'
    exp_name = '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"),
        args.model,
        args.gp,
        'f' + str(args.fold)])
    output_dir = get_outdir(output_base, 'train', exp_name)

    train_input_root = os.path.join(args.data)
    batch_size = args.batch_size
    num_epochs = args.epochs
    wav_size = (16000,)
    num_classes = len(dataset.get_labels())

    torch.manual_seed(args.seed)

    model = model_factory.create_model(
        args.model,
        in_chs=1,
        pretrained=args.pretrained,
        num_classes=num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        checkpoint_path=args.initial_checkpoint)
    # model.reset_classifier(num_classes=num_classes)

    dataset_train = dataset.CommandsDataset(
        root=train_input_root,
        mode='train',
        fold=args.fold,
        wav_size=wav_size,
        format='spectrogram',
    )
    loader_train = data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=args.workers
    )

    dataset_eval = dataset.CommandsDataset(
        root=train_input_root,
        mode='validate',
        fold=args.fold,
        wav_size=wav_size,
        format='spectrogram',
    )
    loader_eval = data.DataLoader(
        dataset_eval,
        batch_size=args.batch_size,
        pin_memory=True,
        shuffle=False,
        num_workers=args.workers
    )

    train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = validate_loss_fn.cuda()

    opt_params = list(model.parameters())
    if args.opt.lower() == 'sgd':
        optimizer = optim.SGD(
            opt_params, lr=args.lr,
            momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    elif args.opt.lower() == 'adam':
        optimizer = optim.Adam(
            opt_params, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif args.opt.lower() == 'nadam':
        optimizer = nadam.Nadam(
            opt_params, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif args.opt.lower() == 'adadelta':
        optimizer = optim.Adadelta(
            opt_params, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif args.opt.lower() == 'rmsprop':
        optimizer = optim.RMSprop(
            opt_params, lr=args.lr, alpha=0.9, eps=args.opt_eps,
            momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        raise ValueError('Invalid optimizer: %s' % args.opt)
    del opt_params

    if not args.decay_epochs:
        print('No decay epoch set, using plateau scheduler.')
        lr_scheduler = ReduceLROnPlateau(optimizer, patience=10)
    else:
        lr_scheduler = None

    # Optionally resume from a checkpoint.
    start_epoch = 0 if args.start_epoch is None else args.start_epoch
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                if 'args' in checkpoint:
                    print(checkpoint['args'])
                new_state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    if k.startswith('module'):
                        name = k[7:]  # remove 'module.'
                    else:
                        name = k
                    new_state_dict[name] = v
                model.load_state_dict(new_state_dict)
                if 'optimizer' in checkpoint:
                    optimizer.load_state_dict(checkpoint['optimizer'])
                if 'loss' in checkpoint:
                    train_loss_fn.load_state_dict(checkpoint['loss'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
                start_epoch = checkpoint['epoch'] if args.start_epoch is None else args.start_epoch
            else:
                model.load_state_dict(checkpoint)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            exit(1)

    saver = CheckpointSaver(checkpoint_dir=output_dir)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    # Optional fine-tune of only the final classifier weights for a
    # specified number of epochs (or part of one).
    if not args.resume and args.ft_epochs > 0.:
        if isinstance(model, torch.nn.DataParallel):
            classifier_params = model.module.get_classifier().parameters()
        else:
            classifier_params = model.get_classifier().parameters()
        if args.opt.lower() == 'adam':
            finetune_optimizer = optim.Adam(
                classifier_params, lr=args.ft_lr, weight_decay=args.weight_decay)
        else:
            finetune_optimizer = optim.SGD(
                classifier_params, lr=args.ft_lr, momentum=args.momentum,
                weight_decay=args.weight_decay, nesterov=True)
        finetune_epochs_int = int(np.ceil(args.ft_epochs))
        finetune_final_batches = int(np.ceil(
            (1 - (finetune_epochs_int - args.ft_epochs)) * len(loader_train)))
        print(finetune_epochs_int, finetune_final_batches)
        for fepoch in range(0, finetune_epochs_int):
            if fepoch == finetune_epochs_int - 1 and finetune_final_batches:
                batch_limit = finetune_final_batches
            else:
                batch_limit = 0
            train_epoch(
                fepoch, model, loader_train, finetune_optimizer, train_loss_fn, args,
                output_dir=output_dir, batch_limit=batch_limit)

    best_loss = None
    try:
        for epoch in range(start_epoch, num_epochs):
            if args.decay_epochs:
                adjust_learning_rate(
                    optimizer, epoch, initial_lr=args.lr,
                    decay_rate=args.decay_rate, decay_epochs=args.decay_epochs)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                saver=saver, output_dir=output_dir)

            # Save a recovery checkpoint in case validation blows up.
            saver.save_recovery({
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': train_loss_fn.state_dict(),
                'args': args,
                'gp': args.gp,
            },
                epoch=epoch + 1,
                batch_idx=0)

            step = epoch * len(loader_train)
            eval_metrics = validate(
                step, model, loader_eval, validate_loss_fn, args, output_dir=output_dir)

            if lr_scheduler is not None:
                lr_scheduler.step(eval_metrics['eval_loss'])

            rowd = OrderedDict(epoch=epoch)
            rowd.update(train_metrics)
            rowd.update(eval_metrics)
            with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf:
                dw = csv.DictWriter(cf, fieldnames=rowd.keys())
                if best_loss is None:  # first iteration (epoch == 1 can't be used)
                    dw.writeheader()
                dw.writerow(rowd)

            # Save a proper checkpoint with the eval metric.
            best_loss = saver.save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'args': args,
                'gp': args.gp,
            },
                epoch=epoch + 1,
                metric=eval_metrics['eval_loss'])
    except KeyboardInterrupt:
        pass
    if best_loss is not None:
        print('*** Best loss: {0} (epoch {1})'.format(best_loss[1], best_loss[0]))
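# Aside: the fine-tune stage above passes only the classifier parameters to the
# optimizer, so the backbone weights never update, but gradients are still
# computed for them. A sketch of how one could also skip those gradient
# computations (illustrative helper, not part of the original script):
def freeze_backbone(model):
    # Disable grads everywhere, then re-enable them for the classifier head.
    for p in model.parameters():
        p.requires_grad_(False)
    for p in model.get_classifier().parameters():
        p.requires_grad_(True)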
class BaseTrainer(object): """Base class for Trainer objects. Takes care of checkpointing/logging/resuming training. """ def __init__(self, options): self.options = options if options.multiprocessing_distributed: self.device = torch.device('cuda', options.gpu) else: self.device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu') # override this function to define your model, optimizers etc. self.saver = CheckpointSaver(save_dir=options.checkpoint_dir, overwrite=options.overwrite) if options.rank == 0: self.summary_writer = SummaryWriter(self.options.summary_dir) self.init_fn() self.checkpoint = None if options.resume and self.saver.exists_checkpoint(): self.checkpoint = self.saver.load_checkpoint( self.models_dict, self.optimizers_dict) if self.checkpoint is None: self.epoch_count = 0 self.step_count = 0 else: self.epoch_count = self.checkpoint['epoch'] self.step_count = self.checkpoint['total_step_count'] if self.checkpoint is not None: self.checkpoint_batch_idx = self.checkpoint['batch_idx'] else: self.checkpoint_batch_idx = 0 self.best_performance = float('inf') def load_pretrained(self, checkpoint_file=None): """Load a pretrained checkpoint. This is different from resuming training using --resume. """ if checkpoint_file is not None: checkpoint = torch.load(checkpoint_file) for model in self.models_dict: if model in checkpoint: self.models_dict[model].load_state_dict(checkpoint[model], strict=True) print(f'Checkpoint {model} loaded') def move_dict_to_device(self, dict, device, tensor2float=False): for k, v in dict.items(): if isinstance(v, torch.Tensor): if tensor2float: dict[k] = v.float().to(device) else: dict[k] = v.to(device) # The following methods (with the possible exception of test) have to be implemented in the derived classes def train(self, epoch): raise NotImplementedError('You need to provide an train method') def init_fn(self): raise NotImplementedError('You need to provide an _init_fn method') def train_step(self, input_batch): raise NotImplementedError('You need to provide a _train_step method') def train_summaries(self, input_batch): raise NotImplementedError( 'You need to provide a _train_summaries method') def visualize(self, input_batch): raise NotImplementedError('You need to provide a visualize method') def validate(self): pass def test(self): pass def evaluate(self): pass def fit(self): # Run training for num_epochs epochs for epoch in tqdm(range(self.epoch_count, self.options.num_epochs), total=self.options.num_epochs, initial=self.epoch_count): self.epoch_count = epoch self.train(epoch) return
class BaseTrainer(object): """ Base class for Trainer objects. Takes care of checkpointing/logging/resuming training. """ def __init__(self, options): self.options = options self.endtime = time.time() + self.options.time_to_run self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # override this function to define your model, optimizers etc. self.init_fn() self.saver = CheckpointSaver(save_dir=options.checkpoint_dir) self.summary_writer = SummaryWriter(self.options.summary_dir) self.checkpoint = None if self.options.resume and self.saver.exists_checkpoint(): self.checkpoint = self.saver.load_checkpoint(self.models_dict, self.optimizers_dict, checkpoint_file=self.options.checkpoint) if self.checkpoint is None: self.epoch_count = 0 self.step_count = 0 else: self.epoch_count = self.checkpoint['epoch'] self.step_count = self.checkpoint['total_step_count'] def load_pretrained(self, checkpoint_file=None): """Load a pretrained checkpoint. This is different from resuming training using --resume. """ if checkpoint_file is not None: checkpoint = torch.load(checkpoint_file) for model in self.models_dict: if model in checkpoint: self.models_dict[model].load_state_dict(checkpoint[model]) print('Checkpoint loaded') def train(self): """Training process.""" # Run training for num_epochs epochs for epoch in range(self.epoch_count, self.options.num_epochs): # Create new DataLoader every epoch and (possibly) resume from an arbitrary step inside an epoch train_data_loader = CheckpointDataLoader(self.train_ds, checkpoint=self.checkpoint, batch_size=self.options.batch_size, num_workers=self.options.num_workers, pin_memory=self.options.pin_memory, shuffle=self.options.shuffle_train) # Iterate over all batches in an epoch for step, batch in enumerate(tqdm(train_data_loader, desc='Epoch ' + str(epoch), total=len(self.train_ds) // self.options.batch_size, initial=train_data_loader.checkpoint_batch_idx), train_data_loader.checkpoint_batch_idx): if time.time() < self.endtime: batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k,v in batch.items()} out = self.train_step(batch) self.step_count += 1 # Tensorboard logging every summary_steps steps if self.step_count % self.options.summary_steps == 0: self.train_summaries(batch, *out) # Save checkpoint every checkpoint_steps steps if self.step_count % self.options.checkpoint_steps == 0: self.saver.save_checkpoint(self.models_dict, self.optimizers_dict, epoch, step+1, self.options.batch_size, train_data_loader.sampler.dataset_perm, self.step_count) tqdm.write('Checkpoint saved') # Run validation every test_steps steps if self.step_count % self.options.test_steps == 0: self.test() else: tqdm.write('Timeout reached') self.saver.save_checkpoint(self.models_dict, self.optimizers_dict, epoch, step, self.options.batch_size, train_data_loader.sampler.dataset_perm, self.step_count) tqdm.write('Checkpoint saved') sys.exit(0) # load a checkpoint only on startup, for the next epochs # just iterate over the dataset as usual self.checkpoint=None # save checkpoint after each epoch if (epoch+1) % 10 == 0: # self.saver.save_checkpoint(self.models_dict, self.optimizers_dict, epoch+1, 0, self.step_count) self.saver.save_checkpoint(self.models_dict, self.optimizers_dict, epoch+1, 0, self.options.batch_size, None, self.step_count) return # The following methods (with the possible exception of test) have to be implemented in the derived classes def init_fn(self): raise NotImplementedError('You need to provide an _init_fn method') def train_step(self, 
input_batch): raise NotImplementedError('You need to provide a _train_step method') def train_summaries(self, input_batch): raise NotImplementedError('You need to provide a _train_summaries method') def test(self): pass def error_adaptive_weight(self, fit_joint_error): weight = (1 - 10 * fit_joint_error) weight[weight <= 0] = 0 return weight def keypoint_loss(self, pred_keypoints_2d, gt_keypoints_2d, weight=None): """ Compute 2D reprojection loss on the keypoints. The loss is weighted by the weight The available keypoints are different for each dataset. """ if gt_keypoints_2d.shape[2] == 3: conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone() else: conf = 1 if weight is not None: weight = weight[:, None, None] conf = conf * weight loss = (conf * self.criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean() return loss def keypoint_3d_loss(self, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d, weight=None): """ Compute 3D keypoint loss for the examples that 3D keypoint annotations are available. The loss is weighted by the weight """ if gt_keypoints_3d.shape[2] == 3: tmp = gt_keypoints_3d.new_ones(gt_keypoints_3d.shape[0], gt_keypoints_3d.shape[1], 1) gt_keypoints_3d = torch.cat((gt_keypoints_3d, tmp), dim=2) conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone() gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone() gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1] conf = conf[has_pose_3d == 1] if weight is not None: weight = weight[has_pose_3d == 1, None, None] conf = conf * weight pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1] if len(gt_keypoints_3d) > 0: # Align the origin of the first 24 keypoints with the pelvis. gt_pelvis = (gt_keypoints_3d[:, 2, :] + gt_keypoints_3d[:, 3, :]) / 2 pred_pelvis = (pred_keypoints_3d[:, 2, :] + pred_keypoints_3d[:, 3, :]) / 2 gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :] pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :] # # Align the origin of the first 24 keypoints with the pelvis. # gt_pelvis = (gt_keypoints_3d[:, 2, :] + gt_keypoints_3d[:, 3, :]) / 2 # pred_pelvis = (pred_keypoints_3d[:, 2, :] + pred_keypoints_3d[:, 3, :]) / 2 # gt_keypoints_3d[:, :24, :] = gt_keypoints_3d[:, :24, :] - gt_pelvis[:, None, :] # pred_keypoints_3d[:, :24, :] = pred_keypoints_3d[:, :24, :] - pred_pelvis[:, None, :] # # # Align the origin of the 24 SMPL keypoints with the root joint. # gt_root_joint = gt_keypoints_3d[:, 24] # pred_root_joint = pred_keypoints_3d[:, 24] # gt_keypoints_3d[:, 24:, :] = gt_keypoints_3d[:, 24:, :] - gt_root_joint[:, None, :] # pred_keypoints_3d[:, 24:, :] = pred_keypoints_3d[:, 24:, :] - pred_root_joint[:, None, :] return (conf * self.criterion_keypoints_3d(pred_keypoints_3d, gt_keypoints_3d)).mean() else: return torch.FloatTensor(1).fill_(0.).to(self.device) def smpl_keypoint_3d_loss(self, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d, weight=None): """ Compute 3D SMPL keypoint loss for the examples that 3D keypoint annotations are available. 
The loss is weighted by the weight """ if gt_keypoints_3d.shape[2] == 3: tmp = gt_keypoints_3d.new_ones(gt_keypoints_3d.shape[0], gt_keypoints_3d.shape[1], 1) gt_keypoints_3d = torch.cat((gt_keypoints_3d, tmp), dim=2) conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone() gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone() gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1] conf = conf[has_pose_3d == 1] if weight is not None: weight = weight[has_pose_3d == 1, None, None] conf = conf * weight pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1] if len(gt_keypoints_3d) > 0: gt_root_joint = gt_keypoints_3d[:, 0, :] pred_root_joint = pred_keypoints_3d[:, 0, :] gt_keypoints_3d = gt_keypoints_3d - gt_root_joint[:, None, :] pred_keypoints_3d = pred_keypoints_3d - pred_root_joint[:, None, :] return (conf * self.criterion_keypoints_3d(pred_keypoints_3d, gt_keypoints_3d)).mean() else: return torch.FloatTensor(1).fill_(0.).to(self.device) def shape_loss(self, pred_vertices, gt_vertices, has_smpl, weight=None): """Compute per-vertex loss on the shape for the examples that SMPL annotations are available.""" pred_vertices_with_shape = pred_vertices[has_smpl == 1] gt_vertices_with_shape = gt_vertices[has_smpl == 1] if weight is not None: weight = weight[has_smpl == 1, None, None] else: weight = 1 if len(gt_vertices_with_shape) > 0: loss = self.criterion_shape(pred_vertices_with_shape, gt_vertices_with_shape) loss = (loss * weight).mean() return loss else: return torch.FloatTensor(1).fill_(0.).to(self.device) def uv_loss(self, pred_uv_map, gt_uv_map, has_smpl, weight=None): # self.uv_mask = self.uv_mask.to(pred_uv_map.device) self.uv_weight = self.uv_weight.to(pred_uv_map.device).type(pred_uv_map.dtype) max = self.uv_weight.max() pred_uv_map_shape = pred_uv_map[has_smpl == 1] gt_uv_map_with_shape = gt_uv_map[has_smpl == 1] if len(gt_uv_map_with_shape) > 0: # return self.criterion_uv(pred_uv_map_shape * self.uv_mask, gt_uv_map_with_shape * self.uv_mask) if weight is not None: ada_weight = weight[has_smpl > 0, None, None, None] else: ada_weight = 1.0 loss = self.criterion_uv(pred_uv_map_shape * self.uv_weight, gt_uv_map_with_shape * self.uv_weight) loss = (loss * ada_weight).mean() return loss else: # return torch.FloatTensor(1).fill_(0.).to(self.device) return torch.tensor(0.0, dtype=pred_uv_map.dtype, device=self.device) def tv_loss(self, uv_map): self.uv_weight = self.uv_weight.to(uv_map.device) tv = torch.abs(uv_map[:,0:-1, 0:-1, :] - uv_map[:,0:-1, 1:, :]) \ + torch.abs(uv_map[:,0:-1, 0:-1, :] - uv_map[:,1:, 0:-1, :]) return torch.sum(tv) / self.tv_factor # return torch.sum(tv * self.uv_weight[:, 0:-1, 0:-1]) / self.tv_factor def dp_loss(self, pred_dp, gt_dp, has_dp, weight=None): dtype = pred_dp.dtype pred_dp_shape = pred_dp[[has_dp > 0]] gt_dp_shape = gt_dp[[has_dp > 0]] if len(gt_dp_shape) > 0: gt_mask_shape = (gt_dp_shape[:, 0].unsqueeze(1) > 0).type(dtype) gt_uv_shape = gt_dp_shape[:, 1:] pred_mask_shape = pred_dp_shape[:, 0].unsqueeze(1) pred_uv_shape = pred_dp_shape[:, 1:] pred_mask_shape = F.interpolate(pred_mask_shape, [gt_dp.shape[2], gt_dp.shape[3]], mode='bilinear') pred_uv_shape = F.interpolate(pred_uv_shape, [gt_dp.shape[2], gt_dp.shape[3]], mode='nearest') if weight is not None: weight = weight[has_dp > 0, None, None, None] else: weight = 1.0 pred_mask_shape = pred_mask_shape.clamp(min=0.0, max=1.0) loss_mask = torch.nn.BCELoss(reduction='none')(pred_mask_shape, gt_mask_shape) loss_mask = (loss_mask * weight).mean() gt_uv_weight = (gt_uv_shape.abs().max(dim=1, keepdim=True)[0] > 
0).type(dtype) weight_ratio = (gt_uv_weight.mean(dim=-1).mean(dim=-1)[:, :, None, None] + 1e-8) gt_uv_weight = gt_uv_weight / weight_ratio # normalized the weight according to mask area loss_uv = self.criterion_uv(gt_uv_weight * pred_uv_shape, gt_uv_weight * gt_uv_shape) loss_uv = (loss_uv * weight).mean() return loss_mask, loss_uv else: return pred_dp.sum() * 0, pred_dp.sum() * 0 def consistent_loss(self, dp, uv_map, camera, weight=None): tmp = torch.arange(0, dp.shape[-1], 1, dtype=dp.dtype, device=dp.device) / (dp.shape[-1] -1) tmp = tmp * 2 - 1 loc_y, loc_x = torch.meshgrid(tmp, tmp) loc = torch.stack((loc_x, loc_y), dim=0).expand(dp.shape[0], -1, -1, -1) dp_mask = (dp[:, 0] > 0.5).float().unsqueeze(1) loc = dp_mask * loc dp_tmp = dp_mask * (dp[:, 1:] * 2 - 1) '''uv_map need to be transfered to img coordinate first''' uv_map = uv_map[:, :, :, :-1] camera = camera.view(-1, 1, 1, 3) uv_map = uv_map + camera[:, :, :, 1:] # trans uv_map = uv_map * camera[:, :, :, 0].unsqueeze(-1) # scale warp_loc = F.grid_sample(uv_map.permute(0, 3, 1, 2), dp_tmp.permute(0, 2, 3, 1))[:, :2] warp_loc = warp_loc * dp_mask if weight is not None: weight = weight[:, None, None, None] dp_mask = dp_mask * weight loss_con = torch.nn.MSELoss()(warp_loc * dp_mask, loc * dp_mask) return loss_con
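# A toy, self-contained illustration of the confidence-weighted keypoint loss
# pattern used above. All tensors and shapes here are made up for the example;
# the elementwise criterion must use reduction='none' for the masking to work.
import torch

criterion = torch.nn.MSELoss(reduction='none')
pred = torch.randn(2, 24, 2)                                      # (batch, joints, xy)
gt = torch.cat([torch.randn(2, 24, 2),
                torch.randint(0, 2, (2, 24, 1)).float()], dim=2)  # xy + confidence
conf = gt[:, :, -1:].clone()
loss = (conf * criterion(pred, gt[:, :, :2])).mean()
# Joints with conf == 0 (unannotated) contribute nothing to the loss.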
def train(args, logger, tb_writer):
    logger.info('Args: {}'.format(json.dumps(vars(args), indent=4, sort_keys=True)))
    if args.local_rank in [-1, 0]:
        with open(os.path.join(args.save_dir, 'args.yaml'), 'w') as file:
            yaml.safe_dump(vars(args), file, sort_keys=False)

    device_id = args.local_rank if args.local_rank != -1 else 0
    device = torch.device('cuda', device_id)
    logger.warning(f'Using GPU {args.local_rank}.')

    world_size = torch.distributed.get_world_size() if args.local_rank != -1 else 1
    logger.info(f'Total number of GPUs used: {world_size}.')
    effective_batch_size = args.batch_size * world_size * args.accumulation_steps
    logger.info(f'Effective batch size: {effective_batch_size}.')

    num_train_samples_per_epoch, num_dev_samples, num_unique_train_epochs = \
        get_data_sizes(data_dir=args.data_dir, num_epochs=args.num_epochs, logger=logger)
    num_optimization_steps = sum(num_train_samples_per_epoch) // world_size // \
        args.batch_size // args.accumulation_steps
    if args.max_steps > 0:
        num_optimization_steps = min(num_optimization_steps, args.max_steps)
    logger.info(f'Total number of optimization steps: {num_optimization_steps}.')

    # Set random seed
    logger.info(f'Using random seed {args.seed}.')
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Get model
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    logger.info(f'Loading model {args.model} for task {args.task}...')
    model = ModelRegistry.get_model(args.task).from_pretrained(args.model)
    if args.local_rank in [-1, 0]:
        with open(os.path.join(args.save_dir, 'config.json'), 'w') as file:
            json.dump(model.config.__dict__, file)
    if args.local_rank == 0:
        torch.distributed.barrier()
    model.to(device)

    # Get optimizer
    logger.info('Creating optimizer...')
    parameter_groups = get_parameter_groups(model)
    optimizer = AdamW(parameter_groups, lr=args.learning_rate,
                      weight_decay=args.weight_decay, eps=1e-8)
    scheduler = get_lr_scheduler(optimizer, num_steps=num_optimization_steps,
                                 warmup_proportion=args.warmup_proportion)

    if args.amp:
        amp.register_half_function(torch, 'einsum')
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt_level)

    if args.local_rank != -1:
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)

    # Get dev data loader
    dev_data_file = os.path.join(args.data_dir, 'dev.jsonl.gz')
    logger.info(f'Creating dev dataset from {dev_data_file}...')
    dev_dataset = DatasetRegistry.get_dataset(args.task)(
        data_file=dev_data_file, data_size=num_dev_samples, local_rank=-1)
    dev_loader = DataLoader(dev_dataset, batch_size=2 * args.batch_size,
                            num_workers=1, collate_fn=dev_dataset.collate_fn)

    # Get evaluator
    evaluator = EvaluatorRegistry.get_evaluator(args.task)(
        data_loader=dev_loader, logger=logger, tb_writer=tb_writer,
        device=device, world_size=world_size, args=args)

    # Get saver
    saver = CheckpointSaver(save_dir=args.save_dir,
                            max_checkpoints=args.max_checkpoints,
                            primary_metric=evaluator.primary_metric,
                            maximize_metric=evaluator.maximize_metric,
                            logger=logger)

    global_step = 0
    samples_processed = 0

    # Train
    logger.info('Training...')
    samples_till_eval = args.eval_every
    for epoch in range(1, args.num_epochs + 1):
        # Get train data loader for current epoch
        train_data_file_num = ((epoch - 1) % num_unique_train_epochs) + 1
        train_data_file = os.path.join(args.data_dir,
                                       f'epoch_{train_data_file_num}.jsonl.gz')
        logger.info(f'Creating training dataset from {train_data_file}...')
        train_dataset = DatasetRegistry.get_dataset(args.task)(
            train_data_file,
            data_size=num_train_samples_per_epoch[epoch - 1],
            local_rank=args.local_rank,
            world_size=world_size)
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                                  num_workers=1, collate_fn=train_dataset.collate_fn)

        logger.info(f'Starting epoch {epoch}...')
        model.train()
        model.zero_grad()
        loss_values = defaultdict(float)
        samples_till_end = (num_optimization_steps - global_step) * effective_batch_size
        samples_in_cur_epoch = min([len(train_loader.dataset), samples_till_end])
        disable_progress_bar = (args.local_rank not in [-1, 0])
        with tqdm(total=samples_in_cur_epoch, disable=disable_progress_bar) as progress_bar:
            for step, batch in enumerate(train_loader, 1):
                batch = {name: tensor.to(device) for name, tensor in batch.items()}
                current_batch_size = batch['input_ids'].shape[0]

                outputs = model(**batch)
                loss, current_loss_values = outputs[:2]
                loss = loss / args.accumulation_steps
                for name, value in current_loss_values.items():
                    loss_values[name] += value / args.accumulation_steps

                if args.amp:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                samples_processed += current_batch_size * world_size
                samples_till_eval -= current_batch_size * world_size
                progress_bar.update(current_batch_size * world_size)

                if step % args.accumulation_steps == 0:
                    current_lr = scheduler.get_last_lr()[0]
                    if args.amp:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                    global_step += 1

                    # Log info
                    progress_bar.set_postfix(epoch=epoch, step=global_step,
                                             lr=current_lr, **loss_values)
                    if args.local_rank in [-1, 0]:
                        tb_writer.add_scalar('train/LR', current_lr, global_step)
                        for name, value in loss_values.items():
                            tb_writer.add_scalar(f'train/{name}', value, global_step)
                    loss_values = {name: 0 for name in loss_values}

                    if global_step == args.max_steps:
                        logger.info('Reached maximum number of optimization steps.')
                        break

                    if samples_till_eval <= 0:
                        samples_till_eval = args.eval_every
                        eval_results = evaluator.evaluate(model, global_step)
                        if args.local_rank in [-1, 0]:
                            saver.save(model, global_step, eval_results)

        if not args.do_not_eval_after_epoch:
            eval_results = evaluator.evaluate(model, global_step)
            if args.local_rank in [-1, 0]:
                saver.save(model, global_step, eval_results)
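# A compact, self-contained sketch of the gradient-accumulation pattern used in
# the loop above. The model and optimizer here are stand-ins, not the script's
# objects:
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accumulation_steps = 4
for step in range(1, 17):
    x, y = torch.randn(2, 8), torch.randn(2, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / accumulation_steps).backward()  # scale so summed grads average out
    if step % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()  # one optimizer update per 4 micro-batches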
class KeypointTester():
    def __init__(self, options):
        self.options = options
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        test_transform_list = []
        if self.options.max_scale > 1:
            test_transform_list.append(RandomRescaleBB(1.0, self.options.max_scale))
        test_transform_list.append(
            CropAndResize(out_size=(self.options.crop_size, self.options.crop_size)))
        test_transform_list.append(
            LocsToHeatmaps(out_size=(self.options.heatmap_size, self.options.heatmap_size)))
        test_transform_list.append(ToTensor())
        test_transform_list.append(Normalize())
        self.test_ds = RctaDataset(root_dir=self.options.dataset_dir,
                                   is_train=False,
                                   transform=transforms.Compose(test_transform_list))

        self.model = StackedHourglass(self.options.num_keypoints).to(self.device)
        # Only create the optimizer because it is required to restore from a checkpoint.
        self.optimizer = torch.optim.RMSprop(params=self.model.parameters(),
                                             lr=0, momentum=0, weight_decay=0)
        self.models_dict = {'stacked_hg': self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}

        print("log dir:", options.log_dir)
        print("checkpoint dir:", options.checkpoint_dir)
        self.saver = CheckpointSaver(save_dir=options.checkpoint_dir)
        print("checkpoint:", self.options.checkpoint)
        self.checkpoint = self.saver.load_checkpoint(
            self.models_dict, self.optimizers_dict,
            checkpoint_file=self.options.checkpoint)
        self.criterion = nn.MSELoss().to(self.device)
        self.pose = Pose2DEval(detection_thresh=self.options.detection_thresh,
                               dist_thresh=self.options.dist_thresh)

    def test(self):
        test_data_loader = DataLoader(self.test_ds,
                                      batch_size=self.options.test_batch_size,
                                      num_workers=self.options.num_workers,
                                      pin_memory=self.options.pin_memory,
                                      shuffle=False)
        pcks = []
        pcks2 = []
        losses = []
        distances = []
        all_object_classes = {}
        for tstep, batch in enumerate(tqdm(test_data_loader, desc='Testing')):
            batch = {k: v.to(self.device) for k, v in batch.items()}

            # Assign each sample to an object class based on the range of
            # keypoint channels that are actually used.
            object_classes = []
            used_keypoints = torch.sum(batch['keypoint_heatmaps'], axis=[2, 3])
            for i in range(batch['keypoint_heatmaps'].shape[0]):
                nonzero = torch.nonzero(used_keypoints[i, :])
                limits = (nonzero[0][0].item(), nonzero[-1][0].item())
                curr_class = None
                for c in all_object_classes:
                    if limits[0] <= c and c <= limits[1]:
                        curr_class = c
                        all_object_classes[c]['limits'] = (
                            min(limits[0], all_object_classes[c]['limits'][0]),
                            max(limits[1], all_object_classes[c]['limits'][1]))
                if curr_class is None:
                    curr_class = int((limits[0] + limits[1]) // 2)
                    all_object_classes[curr_class] = {
                        'limits': limits,
                        'index': len(all_object_classes),
                        'pcks': [],
                        'distances': [],
                        'dir': join(self.options.log_dir, "test_" + self.options.name,
                                    "class_{}".format(len(all_object_classes)))
                    }
                    if not isdir(all_object_classes[curr_class]['dir']):
                        makedirs(all_object_classes[curr_class]['dir'])
                object_classes.append(curr_class)
            print("object classes:", object_classes)

            pred_keypoints, loss = self._test_step(batch)
            print("input images:", batch['image'].shape)
            denormed_batch = Denormalize()(batch)
            losses.append(loss.data.cpu().item())
            shape = pred_keypoints[-1].shape
            for i in range(shape[0]):
                pcks.append(self.pose.pck(
                    batch['keypoint_heatmaps'][i].reshape(1, shape[1], shape[2], shape[3]),
                    pred_keypoints[-1][i].reshape(1, shape[1], shape[2], shape[3])))
                locs = self.pose.heatmaps_to_locs(
                    pred_keypoints[-1][i].reshape(1, shape[1], shape[2], shape[3]),
                    no_thresh=True)
                gt_locs = self.pose.heatmaps_to_locs(
                    batch['keypoint_heatmaps'][i].reshape(1, shape[1], shape[2], shape[3]))
                for k in range(gt_locs.shape[1]):
                    # Keypoints at (0, 0) are unannotated; skip them.
                    if gt_locs[0][k][0] == 0 and gt_locs[0][k][1] == 0:
                        continue
                    dist = np.sqrt((gt_locs[0][k][0] - locs[0][k][0])**2 +
                                   (gt_locs[0][k][1] - locs[0][k][1])**2)
                    all_object_classes[object_classes[i]]['distances'].append(dist)
                    distances.append(dist)
                all_object_classes[object_classes[i]]['pcks'].append(pcks[-1])

                input_image = transforms.ToPILImage()(denormed_batch['image'][i].cpu())
                input_image.save(join(all_object_classes[object_classes[i]]['dir'],
                                      "input_im_{:04d}.png".format(len(pcks))))
                print("before shape:", batch['keypoint_heatmaps'][i].shape)
                print("before heatmap:",
                      (torch.sum(batch['keypoint_heatmaps'][i], axis=[0]).cpu()).min(),
                      (torch.sum(batch['keypoint_heatmaps'][i], axis=[0]).cpu()).max())
                print("input_im:", input_image.size,
                      np.array(input_image).max(), np.array(input_image).min())

                gt_heatmap_im = transforms.ToPILImage()(
                    torch.sum(batch['keypoint_heatmaps'][i], axis=[0]).cpu())
                print("gt_heatmap:", gt_heatmap_im.size,
                      np.array(gt_heatmap_im).max(), np.array(gt_heatmap_im).min())
                print("numpy gt:", np.array(gt_heatmap_im).max(),
                      np.array(gt_heatmap_im).min())
                gt_heatmap_with_im = Image.fromarray(
                    np.array(input_image) // 2 +
                    np.array(gt_heatmap_im.resize(input_image.size, Image.ANTIALIAS)
                             ).reshape(input_image.size[0], input_image.size[1], 1) // 2)
                gt_heatmap_with_im.save(
                    join(all_object_classes[object_classes[i]]['dir'],
                         "gt_im_with_heatmap_{:04d}.png".format(len(pcks))))
                print("gt_heatmap with:", gt_heatmap_with_im.size,
                      np.array(gt_heatmap_with_im).max(), np.array(gt_heatmap_with_im).min())

                print("before image:",
                      (torch.sum(pred_keypoints[-1][i], axis=[0]).cpu()).min(),
                      (torch.sum(pred_keypoints[-1][i], axis=[0]).cpu()).max(),
                      (torch.sum(pred_keypoints[-1][i], axis=[0]).cpu()).shape)
                pred_heatmap_im = transforms.ToPILImage()(torch.clamp(
                    torch.sum(pred_keypoints[-1][i], axis=[0]).cpu(), 0.0, 1.0))
                pred_heatmap_with_im = Image.fromarray(
                    np.array(input_image) // 2 +
                    np.array(pred_heatmap_im.resize(input_image.size, Image.ANTIALIAS)
                             ).reshape(input_image.size[0], input_image.size[1], 1) // 2)
                pred_heatmap_with_im.save(
                    join(all_object_classes[object_classes[i]]['dir'],
                         "pred_im_with_heatmap_{:04d}.png".format(len(pcks))))

            pcks2.append(self.pose.pck(batch['keypoint_heatmaps'], pred_keypoints[-1]))
            print("heatmaps:", pred_keypoints[0].shape, pred_keypoints[-1].shape)
            print("pcks:", pcks)

        print("Means:", np.mean(pcks), "Std error:", np.std(pcks) / np.sqrt(len(pcks)))
        print("Means2:", np.mean(pcks2), "Std error:", np.std(pcks2) / np.sqrt(len(pcks)))
        print("means 1 and means 2 should be equal")
        print("mean loss:", np.mean(losses))
        for c in all_object_classes:
            print("PCK for class:", c,
                  "Mean:", np.mean(all_object_classes[c]['pcks']),
                  "std error:", np.std(all_object_classes[c]['pcks']) /
                  np.sqrt(len(all_object_classes[c]['pcks'])))
            print("Dist for class:", c,
                  "Mean:", np.mean(all_object_classes[c]['distances']),
                  "std error:", np.std(all_object_classes[c]['distances']) /
                  np.sqrt(len(all_object_classes[c]['distances'])))

    def _test_step(self, input_batch):
        self.model.eval()
        images = input_batch['image']
        gt_keypoints = input_batch['keypoint_heatmaps']
        summed = torch.sum(gt_keypoints, axis=[2, 3])
        with torch.no_grad():
            pred_keypoints = self.model(images)
            loss = torch.tensor(0.0, device=self.device)
            # The stacked hourglass supervises every intermediate stack.
            for i in range(len(pred_keypoints)):
                loss += self.criterion(pred_keypoints[i], gt_keypoints)
        return pred_keypoints, loss
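# For reference, a generic PCK ("percentage of correct keypoints") computation,
# sketched from scratch; the project's Pose2DEval.pck may differ in details
# such as detection thresholding and heatmap decoding:
import numpy as np

def pck(pred_locs, gt_locs, dist_thresh):
    """pred_locs, gt_locs: (num_joints, 2) arrays; unannotated joints are (0, 0)."""
    visible = ~np.all(gt_locs == 0, axis=1)
    if not visible.any():
        return np.nan
    dists = np.linalg.norm(pred_locs[visible] - gt_locs[visible], axis=1)
    return float((dists < dist_thresh).mean())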
def inference_structure(pathCkp: str, pathImg: str = None, pathBgImg: str = None):
    print('If trained locally and you renamed the workspace, do not forget to '
          'change the "checkpoint_dir" in config.json.')

    # Load configuration
    with open(pjn(pathCkp, 'config.json'), 'r') as f:
        options = json.load(f)
        options = namedtuple('options', options.keys())(**options)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    mesh = Mesh(options, options.num_downsampling)

    # Read the SMPL .obj file to get the uv coordinates.
    _, smpl_tri_ind, uv_coord, tri_uv_ind = read_Obj(options.smpl_objfile_path)
    uv_coord[:, 1] = 1 - uv_coord[:, 1]
    expUV = uv_coord[tri_uv_ind.flatten()]
    unique, index = np.unique(smpl_tri_ind.flatten(), return_index=True)
    smpl_verts_uvs = torch.as_tensor(expUV[index, :]).float().to(device)
    smpl_tri_ind = torch.as_tensor(smpl_tri_ind).to(device)

    # Load the average pose and shape and convert to camera coordinates;
    # the average pose is decided by the image id we use for training (0-11).
    avgPose_objCoord = np.load(options.MGN_avgPose_path)
    avgPose_objCoord[:3] = rotationMatrix_to_axisAngle(  # for 0, 6, front only
        torch.tensor([[[1,  0,  0],
                       [0, -1,  0],
                       [0,  0, -1]]]))
    avgPose = axisAngle_to_Rot6d(
        torch.Tensor(avgPose_objCoord[None]).reshape(-1, 3)).reshape(1, -1).to(device)
    avgBeta = torch.Tensor(np.load(options.MGN_avgBeta_path)[None]).to(device)
    avgCam = torch.Tensor([1.2755, 0, 0])[None].to(device)  # 1.2755 is for our settings

    # Create model
    model = frameVIBE(options.smpl_model_path, mesh, avgPose, avgBeta, avgCam,
                      options.num_channels, options.num_layers,
                      smpl_verts_uvs, smpl_tri_ind).to(device)
    optimizer = torch.optim.Adam(params=list(model.parameters()))
    models_dict = {options.model: model}
    optimizers_dict = {'optimizer': optimizer}

    # Load pretrained model
    saver = CheckpointSaver(save_dir=options.checkpoint_dir)
    saver.load_checkpoint(models_dict, optimizers_dict,
                          checkpoint_file=options.checkpoint)

    # Prepare and preprocess the input image.
    IMG_NORM_MEAN = [0.485, 0.456, 0.406]
    IMG_NORM_STD = [0.229, 0.224, 0.225]
    normalize_img = Normalize(mean=IMG_NORM_MEAN, std=IMG_NORM_STD)

    path_to_rendering = '/'.join(pathImg.split('/')[:-1])
    cameraPath, lightPath = pathImg.split('/')[-1].split('_')[:2]
    cameraIdx, _ = int(cameraPath[6:]), int(lightPath[5:])
    with open(pjn(path_to_rendering, 'camera%d_boundingbox.txt' % (cameraIdx))) as f:
        boundbox = literal_eval(f.readline())
    img = cv2.imread(pathImg)[:, :, ::-1].astype(np.float32)

    # Prepare background
    if options.replace_background:
        if pathBgImg is None:
            bgimages = []
            for subfolder in sorted(glob(pjn(options.bgimg_dir, 'images/validation/*'))):
                for subsubfolder in sorted(glob(pjn(subfolder, '*'))):
                    if 'room' in subsubfolder:
                        bgimages += sorted(glob(pjn(subsubfolder, '*.jpg')))
            bgimg = cv2.imread(bgimages[np.random.randint(0, len(bgimages))]
                               )[:, :, ::-1].astype(np.float32)
        else:
            bgimg = cv2.imread(pathBgImg)[:, :, ::-1].astype(np.float32)
        img = background_replacing(img, bgimg)

    # Crop and augment the image.
    center = [(boundbox[0] + boundbox[2]) / 2, (boundbox[1] + boundbox[3]) / 2]
    scale = max((boundbox[2] - boundbox[0]) / 200, (boundbox[3] - boundbox[1]) / 200)
    img = torch.Tensor(crop(img, center, scale, [224, 224], rot=0)).permute(2, 0, 1) / 255
    img_in = normalize_img(img)

    # Inference
    with torch.no_grad():  # disable grad
        model.eval()
        prediction = model(
            img_in[None].repeat_interleave(options.batch_size, dim=0).to(device),
            img[None].repeat_interleave(options.batch_size, dim=0).to(device))

    return prediction, img_in, options
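# Hypothetical usage (the paths below are placeholders, not real files):
# prediction, img_in, options = inference_structure(
#     pathCkp='path/to/checkpoint_dir',
#     pathImg='path/to/subject/rendering/camera0_light0_render.png')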
        transform=transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize((opt.img_size, opt.img_size)),
            transforms.ToTensor(),
            # TODO: normalize according to dataset stats.
            # Grayscale() yields a single channel, so use 1-channel stats here.
            transforms.Normalize((0.5,), (0.5,))
        ])),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu)

# Optimizers
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=opt.lr)
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=opt.lr)

discriminator_saver = CheckpointSaver(opt.save_dir_name, max_checkpoints=3)
generator_saver = CheckpointSaver(opt.save_dir_name, max_checkpoints=3)

# ----------
#  Training
# ----------

batches_done = 0
for epoch in range(opt.n_epochs):
    # Batch iterator
    data_iter = iter(dataloader)
    for i in range(len(data_iter) // opt.n_critic):
        # Train the discriminator for n_critic steps.
        for _ in range(opt.n_critic):
            imgs = next(data_iter)
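# The fragment above ends mid-loop. Below is a small self-contained sketch of
# the WGAN critic/generator updates such a loop typically performs; the tiny
# architectures, clip value, and latent size are illustrative assumptions, not
# taken from the original script:
import torch

latent_dim, clip_value, n_critic = 16, 0.01, 5
G = torch.nn.Sequential(torch.nn.Linear(latent_dim, 64), torch.nn.ReLU(),
                        torch.nn.Linear(64, 28 * 28))
D = torch.nn.Sequential(torch.nn.Linear(28 * 28, 64), torch.nn.ReLU(),
                        torch.nn.Linear(64, 1))
opt_G = torch.optim.RMSprop(G.parameters(), lr=5e-5)
opt_D = torch.optim.RMSprop(D.parameters(), lr=5e-5)

real = torch.randn(8, 28 * 28)  # stand-in for a real image batch
for _ in range(n_critic):
    fake = G(torch.randn(8, latent_dim)).detach()
    loss_D = -D(real).mean() + D(fake).mean()  # Wasserstein critic loss
    opt_D.zero_grad()
    loss_D.backward()
    opt_D.step()
    for p in D.parameters():  # weight clipping (vanilla WGAN)
        p.data.clamp_(-clip_value, clip_value)

# One generator step after n_critic critic steps.
loss_G = -D(G(torch.randn(8, latent_dim))).mean()
opt_G.zero_grad()
loss_G.backward()
opt_G.step()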