def train(params, args, world_rank, local_rank):
    # logging info
    logging.info('rank {:d}, begin data loader init (local rank {:d})'.format(
        world_rank, local_rank))

    # set device
    device = torch.device("cuda:{}".format(local_rank))

    # data loader
    pipe = dl.DaliPipeline(params, num_threads=params.num_data_workers,
                           device_id=device.index)
    pipe.build()
    train_data_loader = DALIGenericIterator([pipe], ['inp', 'tar'],
                                            params.Nsamples, auto_reset=True)
    logging.info('rank %d, data loader initialized' % world_rank)

    model = UNet.UNet(params)
    model.to(device)
    if not args.resuming:
        model.apply(model.get_weights_function(params.weight_init))

    optimizer = optimizers.FusedAdam(model.parameters(), lr=params.lr)
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1")  # for automatic mixed precision
    if params.distributed:
        model = DDP(model, device_ids=[local_rank])

    # loss
    criterion = UNet.CosmoLoss(params.LAMBDA_2)

    # amp stuff: GradScaler handles dynamic loss scaling for mixed precision
    if args.enable_amp:
        gscaler = amp.GradScaler()

    iters = 0
    startEpoch = 0
    checkpoint = None
    if args.resuming:
        if world_rank == 0:
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
        checkpoint = torch.load(params.checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state'])
        iters = checkpoint['iters']
        startEpoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if world_rank == 0:
        logging.info(model)
        logging.info("Starting Training Loop...")

    for epoch in range(startEpoch, startEpoch + params.num_epochs):
        start = time.time()
        nsteps = 0
        fw_time = 0.
        bw_time = 0.
        log_time = 0.

        model.train()
        step_time = time.time()
        # emit_nvtx annotates autograd ops with NVTX ranges for Nsight profiling
        with torch.autograd.profiler.emit_nvtx():
            for data in train_data_loader:
                iters += 1
                #adjust_LR(optimizer, params, iters)
                inp = data[0]["inp"]
                tar = data[0]["tar"]

                if not args.io_only:
                    # fw pass (the -=/+= pairs accumulate elapsed wall time)
                    torch.cuda.nvtx.range_push("cosmo3D:forward")
                    fw_time -= time.time()
                    optimizer.zero_grad()
                    with amp.autocast(enabled=args.enable_amp):
                        gen = model(inp)
                        loss = criterion(gen, tar)
                    fw_time += time.time()
                    torch.cuda.nvtx.range_pop()

                    # bw pass
                    torch.cuda.nvtx.range_push("cosmo3D:backward")
                    bw_time -= time.time()
                    if args.enable_amp:
                        gscaler.scale(loss).backward()
                        gscaler.step(optimizer)
                        gscaler.update()
                    else:
                        loss.backward()
                        optimizer.step()
                    bw_time += time.time()
                    torch.cuda.nvtx.range_pop()

                nsteps += 1

        # epoch done
        dist.barrier()
        step_time = (time.time() - step_time) / float(nsteps)
        fw_time /= float(nsteps)
        bw_time /= float(nsteps)
        io_time = max([step_time - fw_time - bw_time, 0])
        iters_per_sec = 1. / step_time
        end = time.time()
        if world_rank == 0:
            logging.info('Time taken for epoch {} is {} sec'.format(
                epoch + 1, end - start))
            logging.info(
                'total time / step = {}, fw time / step = {}, bw time / step = {}, '
                'exposed io time / step = {}, iters/s = {}, logging time = {}'
                .format(step_time, fw_time, bw_time, io_time, iters_per_sec, log_time))

        ## Output training stats
        #model.eval()
        #if world_rank == 0:
        #    log_start = time.time()
        #    gens = []
        #    tars = []
        #    with torch.no_grad():
        #        for i, data in enumerate(train_data_loader, 0):
        #            if i >= 16:
        #                break
        #            #inp, tar = map(lambda x: x.to(device), data)
        #            inp, tar = data
        #            gen = model(inp)
        #            gens.append(gen.detach().cpu().numpy())
        #            tars.append(tar.detach().cpu().numpy())
        #    gens = np.concatenate(gens, axis=0)
        #    tars = np.concatenate(tars, axis=0)
        #
        #    # Scalars
        #    args.tboard_writer.add_scalar('G_loss', loss.item(), iters)
        #
        #    # Plots
        #    fig, chi, L1score = meanL1(gens, tars)
        #    args.tboard_writer.add_figure('pixhist', fig, iters, close=True)
        #    args.tboard_writer.add_scalar('Metrics/chi', chi, iters)
        #    args.tboard_writer.add_scalar('Metrics/rhoL1', L1score[0], iters)
        #    args.tboard_writer.add_scalar('Metrics/vxL1', L1score[1], iters)
        #    args.tboard_writer.add_scalar('Metrics/vyL1', L1score[2], iters)
        #    args.tboard_writer.add_scalar('Metrics/vzL1', L1score[3], iters)
        #    args.tboard_writer.add_scalar('Metrics/TL1', L1score[4], iters)
        #
        #    fig = generate_images(inp.detach().cpu().numpy()[0], gens[-1], tars[-1])
        #    args.tboard_writer.add_figure('genimg', fig, iters, close=True)
        #    log_end = time.time()
        #    log_time += log_end - log_start
        #
        #    # Save checkpoint
        #    torch.save({'iters': iters, 'epoch': epoch, 'model_state': model.state_dict(),
        #                'optimizer_state_dict': optimizer.state_dict()}, params.checkpoint_path)

    # finalize
    dist.barrier()
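# The train() variant below calls adjust_LR(), which is not defined in this
# excerpt (the variant above has the call commented out). A minimal sketch is
# given here so the code runs end to end; the call signature comes from the
# call site, but the schedule itself (linear warmup to the base rate) is an
# assumption, and 'warmup_iters' is a hypothetical config field not present
# in the original params.
def adjust_LR(optimizer, params, iters):
    """Hypothetical schedule: linear warmup to params.lr (assumption)."""
    warmup_iters = getattr(params, 'warmup_iters', 128)  # hypothetical field
    lr = params.lr * min(1., float(iters + 1) / float(warmup_iters))
    for group in optimizer.param_groups:
        group['lr'] = lr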
def train(params, args, world_rank, local_rank):
    # logging info
    logging.info('rank {:d}, begin data loader init (local rank {:d})'.format(
        world_rank, local_rank))

    # set device
    device = torch.device("cuda:{}".format(local_rank))

    # data loader
    pipe = dl.DaliPipeline(params, num_threads=params.num_data_workers,
                           device_id=device.index)
    pipe.build()
    train_data_loader = DALIGenericIterator([pipe], ['inp', 'tar'],
                                            params.Nsamples, auto_reset=True)
    logging.info('rank %d, data loader initialized' % world_rank)

    model = UNet.UNet(params).to(device)
    if not args.resuming:
        model.apply(model.get_weights_function(params.weight_init))

    optimizer = optimizers.FusedAdam(model.parameters(), lr=params.lr)
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1")  # for automatic mixed precision
    if params.distributed:
        model = DDP(model, device_ids=[device.index], output_device=device.index)

    # loss
    criterion = UNet.CosmoLoss(params.LAMBDA_2)

    # amp stuff: GradScaler handles dynamic loss scaling for mixed precision
    if args.enable_amp:
        gscaler = amp.GradScaler()

    iters = 0
    startEpoch = 0
    checkpoint = None
    if args.resuming:
        if world_rank == 0:
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
        checkpoint = torch.load(params.checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state'])
        iters = checkpoint['iters']
        startEpoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if world_rank == 0:
        logging.info(model)
        logging.info("Starting Training Loop...")

    # emit_nvtx annotates autograd ops with NVTX ranges for Nsight profiling
    with torch.autograd.profiler.emit_nvtx():
        for epoch in range(startEpoch, startEpoch + params.num_epochs):
            if args.global_timing:
                dist.barrier()
            start = time.time()
            epoch_step = 0
            tr_time = 0.
            fw_time = 0.
            bw_time = 0.
            log_time = 0.

            model.train()
            for data in train_data_loader:
                torch.cuda.nvtx.range_push("cosmo3D:step {}".format(iters))
                tr_start = time.time()
                adjust_LR(optimizer, params, iters)

                # fetch data
                inp = data[0]["inp"]
                tar = data[0]["tar"]

                if not args.io_only:
                    # fw pass (the -=/+= pairs accumulate elapsed wall time)
                    torch.cuda.nvtx.range_push(
                        "cosmo3D:forward {}".format(iters))
                    fw_time -= time.time()
                    optimizer.zero_grad()
                    with amp.autocast(enabled=args.enable_amp):
                        gen = model(inp)
                        loss = criterion(gen, tar)
                    fw_time += time.time()
                    torch.cuda.nvtx.range_pop()

                    # bw pass
                    torch.cuda.nvtx.range_push(
                        "cosmo3D:backward {}".format(iters))
                    bw_time -= time.time()
                    if args.enable_amp:
                        gscaler.scale(loss).backward()
                        gscaler.step(optimizer)
                        gscaler.update()
                    else:
                        loss.backward()
                        optimizer.step()
                    bw_time += time.time()
                    torch.cuda.nvtx.range_pop()

                iters += 1
                epoch_step += 1

                # step done
                tr_end = time.time()
                tr_time += tr_end - tr_start
                torch.cuda.nvtx.range_pop()

            # epoch done
            if args.global_timing:
                dist.barrier()
            end = time.time()
            epoch_time = end - start
            step_time = epoch_time / float(epoch_step)
            tr_time /= float(epoch_step)
            fw_time /= float(epoch_step)
            bw_time /= float(epoch_step)
            io_time = max([step_time - fw_time - bw_time, 0])
            iters_per_sec = 1. / step_time
            fw_per_sec = 1. / tr_time

            if world_rank == 0:
                logging.info('Time taken for epoch {} is {} sec'.format(
                    epoch + 1, epoch_time))
                logging.info(
                    'train step time = {} ({} steps), logging time = {}'.format(
                        tr_time, epoch_step, log_time))
                logging.info('train samples/sec = {} fw steps/sec = {}'.format(
                    iters_per_sec, fw_per_sec))
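# A minimal launch sketch, assuming one process per GPU started by torchrun
# (or an equivalent srun/MPI wrapper). How the surrounding script builds
# 'params' and 'args' (config file, argparse, ...) is not shown in this
# excerpt, so their construction is omitted; the environment-variable names
# are the standard torch.distributed ones, not confirmed by the source.
# The NVTX ranges above become visible when profiling with Nsight Systems,
# e.g. (assumed script name):
#
#   nsys profile --trace=cuda,nvtx -o train_profile \
#       torchrun --nproc_per_node=8 train.py
#
if __name__ == '__main__':
    import os

    world_rank = int(os.environ.get('RANK', 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    if params.distributed:
        # NCCL backend for GPU collectives; rank/world size come from env://
        dist.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(local_rank)
    train(params, args, world_rank, local_rank)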