def train(train_dataloader, model, epoch, loss_func,
          optimizer, scheduler, training_stats,
          val_dataloader=None, val_err=[], ignore_step=-1):
    """
    Train the model in steps
    """
    model.train()
    epoch_steps = math.ceil(len(train_dataloader) / cfg.TRAIN.BATCHSIZE)
    base_steps = (epoch_steps * epoch + ignore_step) if ignore_step != -1 else epoch_steps * epoch
    for i, data in enumerate(train_dataloader):
        if ignore_step != -1 and i > epoch_steps - ignore_step:
            return
        scheduler.step()  # decay lr every iteration
        training_stats.IterTic()
        out = model(data)
        losses = loss_func.criterion(out['b_fake_softmax'], out['b_fake_logit'], data, epoch)
        optimizer.optim(losses)
        step = base_steps + i + 1
        training_stats.UpdateIterStats(losses)
        training_stats.IterToc()
        training_stats.LogIterStats(step, epoch, optimizer.optimizer, val_err[0])

        # validate the model
        if step % cfg.TRAIN.VAL_STEP == 0 and step != 0 and val_dataloader is not None:
            model.eval()
            val_err[0] = val(val_dataloader, model)
            # back to training mode
            model.train()

        # save checkpoint
        if step % cfg.TRAIN.SNAPSHOT_ITERS == 0 and step != 0:
            save_ckpt(train_args, step, epoch, model, optimizer.optimizer, scheduler, val_err[0])
def train(train_dataloader, model, epoch, loss_func,
          optimizer, scheduler, training_stats,
          val_dataloader=None, val_err=[], ignore_step=-1):
    """
    Train the model in steps
    """
    model.train()
    epoch_steps = math.ceil(len(train_dataloader) / cfg.TRAIN.BATCHSIZE)
    base_steps = (epoch_steps * epoch + ignore_step) if ignore_step != -1 else epoch_steps * epoch
    for i, data in enumerate(train_dataloader):
        print("step:", i)
        if ignore_step != -1 and i > epoch_steps - ignore_step:
            return
        scheduler.step()  # decay lr every iteration
        training_stats.IterTic()
        out = model(data)
        # Debug visualization of the input image, ground-truth depth and predicted depth:
        # image = data['A'][0]
        # img = torchvision.transforms.ToPILImage()(image)
        # img.show()
        # gt = data['B'][0]
        # gt = torchvision.transforms.ToPILImage()(gt)
        # gt.show()
        # depth = bins_to_depth(out['b_fake_softmax'])[0]
        # depth_img = torchvision.transforms.ToPILImage()(depth)
        # depth_img.show()
        if train_args.refine:
            losses = loss_func.criterion(out['refined_depth'], data)
        else:
            losses = loss_func.criterion(out['b_fake_softmax'], out['b_fake_logit'], data, epoch)
        optimizer.optim(losses)
        step = base_steps + i + 1
        training_stats.UpdateIterStats(losses)
        training_stats.IterToc()
        training_stats.LogIterStats(step, epoch, optimizer.optimizer, val_err[0])

        # validate the model
        if step % cfg.TRAIN.VAL_STEP == 0 and step != 0 and val_dataloader is not None:
            model.eval()
            val_err[0] = val(val_dataloader, model)
            # back to training mode
            model.train()

        # save checkpoint, then stop this epoch
        if step % cfg.TRAIN.SNAPSHOT_ITERS == 0 and step != 0:
            save_ckpt(train_args, step, epoch, model, optimizer.optimizer, scheduler, val_err[0])
            break
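# A minimal sketch of how the epoch-based train() above might be driven. The
# run_epochs() name and the initial content of val_err are assumptions for
# illustration only; the actual driver code in this repository may differ.
def run_epochs(train_dataloader, val_dataloader, model, loss_func,
               optimizer, scheduler, training_stats, num_epochs):
    # single-element container so train()/val() can update the latest
    # validation error in place (the real code may store a dict of metrics here)
    val_err = [0]
    for epoch in range(num_epochs):
        train(train_dataloader, model, epoch, loss_func,
              optimizer, scheduler, training_stats,
              val_dataloader=val_dataloader, val_err=val_err)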
            # get the next data batch; restart the iterator once the dataloader is exhausted
            try:
                data = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(train_dataloader)
                data = next(dataloader_iterator)

            training_stats.IterTic()
            out = model(data)
            losses = loss_func.criterion(out['b_fake_softmax'], out['b_fake_logit'], data)
            optimizer.optim(losses)
            training_stats.UpdateIterStats(losses)
            training_stats.IterToc()
            training_stats.LogIterStats(step, 0, optimizer.optimizer, val_err[0])

            # validate the model
            if (step + 1) % cfg.TRAIN.VAL_STEP == 0 and val_dataloader is not None and step != 0:
                model.eval()
                val_err[0] = val(val_dataloader, model)
                # back to training mode
                model.train()

            # save checkpoint
            if step % cfg.TRAIN.SNAPSHOT_ITERS == 0 and step != 0:
                save_ckpt(train_args, step, epoch, model, optimizer.optimizer, scheduler, val_err[0])
    except (RuntimeError, KeyboardInterrupt):
        logger.info('Save ckpt on exception ...')
        stack_trace = traceback.format_exc()
        print(stack_trace)
    finally:
        if train_args.use_tfboard:
            tblogger.close()
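# A minimal sketch of the kind of helper save_ckpt() above could be, assuming it
# bundles the step/epoch counters, model/optimizer/scheduler state and the last
# validation error into one torch checkpoint file. The name, directory layout and
# keys are assumptions for illustration; the repository's real helper may differ.
import os
import torch

def save_ckpt_sketch(args, step, epoch, model, optimizer, scheduler, val_err=0.0):
    """Save a training checkpoint (illustrative only)."""
    ckpt_dir = getattr(args, 'log_dir', './ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)
    save_path = os.path.join(ckpt_dir, 'epoch%d_step%d.pth' % (epoch, step))
    torch.save({
        'step': step,
        'epoch': epoch,
        'val_err': val_err,
        'model_state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }, save_path)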
def do_train(train_dataloader, val_dataloader, train_args,
             model, save_to_disk,
             scheduler, optimizer, val_err,
             logger, tblogger=None):
    print(cfg.TRAIN.BASE_LR)
    # training status for logging
    if save_to_disk:
        training_stats = TrainingStats(train_args, cfg.TRAIN.LOG_INTERVAL,
                                       tblogger if train_args.use_tfboard else None)

    dataloader_iterator = iter(train_dataloader)
    start_step = train_args.start_step
    total_iters = cfg.TRAIN.MAX_ITER
    train_datasize = train_dataloader.batch_sampler.sampler.total_sampled_size
    pytorch_1_1_0_or_later = is_pytorch_1_1_0_or_later()
    tmp_i = 0
    try:
        for step in range(start_step, total_iters):
            # periodically enlarge the sampled subset of the training data
            if step % train_args.sample_ratio_steps == 0 and step != 0:
                sample_ratio = increase_sample_ratio_steps(step,
                                                           base_ratio=train_args.sample_start_ratio,
                                                           step_size=train_args.sample_ratio_steps)
                train_dataloader, curr_sample_size = MultipleDataLoaderDistributed(train_args,
                                                                                   sample_ratio=sample_ratio)
                dataloader_iterator = iter(train_dataloader)
                logger.info('Sample ratio: %02f, current sampled datasize: %d'
                            % (sample_ratio, np.sum(curr_sample_size)))

            epoch = int(step * train_args.batchsize * train_args.world_size / train_datasize)

            if save_to_disk:
                training_stats.IterTic()

            # get the next data batch; restart the iterator once the dataloader is exhausted
            try:
                data = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(train_dataloader)
                data = next(dataloader_iterator)

            out = model(data)
            losses_dict = out['losses']
            optimizer.optim(losses_dict)

            ################# Check data loading ######################
            # tmp_path_base = '/home/yvan/DeepLearning/Depth/DiverseDepth-github/DiverseDepth/datasets/x/'
            # rgb = data['A'][1, ...].permute(1, 2, 0).squeeze()
            # rgb = rgb * torch.tensor(cfg.DATASET.RGB_PIXEL_VARS)[None, None, :] + torch.tensor(cfg.DATASET.RGB_PIXEL_MEANS)[None, None, :]
            # rgb = rgb * 255
            # rgb = rgb.cpu().numpy().astype(np.uint8)
            # depth = (data['B'][1, ...].squeeze().cpu().numpy() * 1000)
            # depth[depth < 0] = 0
            # depth = depth.astype(np.uint16)
            # plt.imsave(tmp_path_base + '%04d_r.jpg' % tmp_i, rgb)
            # plt.imsave(tmp_path_base + '%04d_d.png' % tmp_i, depth, cmap='rainbow')
            # tmp_i += 1
            ############################################################

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(losses_dict)

            scheduler.step()

            if save_to_disk:
                training_stats.UpdateIterStats(loss_dict_reduced)
                training_stats.IterToc()
                training_stats.LogIterStats(step, epoch, optimizer.optimizer, val_err[0])

            # validate the model
            if step % cfg.TRAIN.VAL_STEP == 0 and val_dataloader is not None and step != 0:
                model.eval()
                val_err[0] = val(val_dataloader, model)
                # back to training mode
                model.train()

            # save checkpoint
            if step % cfg.TRAIN.SNAPSHOT_ITERS == 0 and step != 0 and save_to_disk:
                save_ckpt(train_args, step, epoch, model, optimizer.optimizer, scheduler, val_err[0])
    except (RuntimeError, KeyboardInterrupt):
        stack_trace = traceback.format_exc()
        print(stack_trace)
    finally:
        if train_args.use_tfboard and main_process(train_args):
            tblogger.close()
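# A minimal sketch of the kind of helper reduce_loss_dict() above refers to:
# averaging each named loss across all distributed workers so that the logged
# values are consistent on every rank. This follows the common torch.distributed
# pattern and is an assumption about the helper, not its actual implementation
# in this repository.
import torch
import torch.distributed as dist

def reduce_loss_dict_sketch(loss_dict):
    """Average a {name: scalar tensor} dict over all processes (illustrative only)."""
    if not dist.is_available() or not dist.is_initialized() or dist.get_world_size() < 2:
        return loss_dict
    with torch.no_grad():
        names = sorted(loss_dict.keys())
        stacked = torch.stack([loss_dict[k] for k in names], dim=0)
        dist.all_reduce(stacked, op=dist.ReduceOp.SUM)
        stacked /= dist.get_world_size()
        return {k: v for k, v in zip(names, stacked)}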
def do_train(train_dataloader, val_dataloader, train_args,
             model, save_to_disk,
             scheduler, optimizer, val_err,
             logger, tblogger=None):
    # training status for logging
    if save_to_disk:
        training_stats = TrainingStats(train_args, cfg.TRAIN.LOG_INTERVAL,
                                       tblogger if train_args.use_tfboard else None)

    dataloader_iterator = iter(train_dataloader)
    start_step = train_args.start_step
    total_iters = cfg.TRAIN.MAX_ITER
    train_datasize = len(train_dataloader)
    pytorch_1_1_0_or_later = is_pytorch_1_1_0_or_later()
    try:
        for step in range(start_step, total_iters):
            # periodically enlarge the sampled subset of the training data
            if step % train_args.sample_ratio_steps == 0 and step != 0:
                sample_ratio = increase_sample_ratio_steps(step,
                                                           base_ratio=train_args.sample_start_ratio,
                                                           step_size=train_args.sample_ratio_steps)
                train_dataloader = MultipleDataLoaderDistributed(train_args, sample_ratio=sample_ratio)
                dataloader_iterator = iter(train_dataloader)
                logger.info('Sample ratio: %02f, current sampled datasize: %d'
                            % (sample_ratio, np.sum(train_dataloader.curr_sample_size)))

            # in pytorch >= 1.1.0, scheduler.step() should be run after optimizer.step()
            if not pytorch_1_1_0_or_later:
                scheduler.step()

            epoch = int(step * train_args.batchsize / train_datasize)

            if save_to_disk:
                training_stats.IterTic()

            # get the next data batch; restart the iterator once the dataloader is exhausted
            try:
                data = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(train_dataloader)
                data = next(dataloader_iterator)

            out = model(data)
            losses_dict = out['losses']
            optimizer.optim(losses_dict)

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(losses_dict)

            if pytorch_1_1_0_or_later:
                scheduler.step()

            if save_to_disk:
                training_stats.UpdateIterStats(loss_dict_reduced)
                training_stats.IterToc()
                training_stats.LogIterStats(step, epoch, optimizer.optimizer, val_err[0])

            # validate the model
            if step % cfg.TRAIN.VAL_STEP == 0 and val_dataloader is not None and step != 0:
                model.eval()
                val_err[0] = val(val_dataloader, model)
                # back to training mode
                model.train()

            # save checkpoint
            if step % cfg.TRAIN.SNAPSHOT_ITERS == 0 and step != 0 and save_to_disk:
                save_ckpt(train_args, step, epoch, model, optimizer.optimizer, scheduler, val_err[0])
    except (RuntimeError, KeyboardInterrupt):
        stack_trace = traceback.format_exc()
        print(stack_trace)
    finally:
        if train_args.use_tfboard and get_rank() == 0:
            tblogger.close()
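# A minimal sketch of the is_pytorch_1_1_0_or_later() check used above, assuming
# it simply parses torch.__version__; the real helper in this repository may differ.
import torch

def is_pytorch_1_1_0_or_later_sketch():
    """Return True when the installed PyTorch is >= 1.1.0 (illustrative only)."""
    version = torch.__version__.split('+')[0]  # drop local build suffixes such as '+cu102'
    major, minor = (int(x) for x in version.split('.')[:2])
    return (major, minor) >= (1, 1)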