# Imports reconstructed from usage; UNet, the data-loader helpers, and the
# LR/plotting utilities are project-local modules, so their paths below are
# assumptions.
import logging
import time

import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from apex import optimizers  # NVIDIA Apex; provides FusedAdam

import UNet
from data_loader import get_data_loader_distributed, get_data_loader_distributed_test  # assumed path
from utils import adjust_LR, plot_gens_tars  # assumed path

DDP = DistributedDataParallel  # alias used by the first train() variant


def train(params, args, world_rank, local_rank):
    logging.info('rank {:d}, begin data loader init (local rank {:d})'.format(
        world_rank, local_rank))
    #train_data_loader = get_data_loader_distributed(params, world_rank)
    train_data_loader = get_data_loader_distributed(params, world_rank, local_rank)
    logging.info('rank %d, data loader initialized' % world_rank)

    # set device and build the model
    device = torch.device("cuda:{}".format(local_rank))
    model = UNet.UNet(params)
    model.to(device)
    if not args.resuming:
        model.apply(model.get_weights_function(params.weight_init))

    optimizer = optimizers.FusedAdam(model.parameters(), lr=params.lr)
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1")  # for automatic mixed precision
    if params.distributed:
        model = DDP(model, device_ids=[local_rank])

    # amp stuff
    #gscaler = amp.GradScaler()

    iters = 0
    startEpoch = 0
    checkpoint = None
    if args.resuming:
        if world_rank == 0:
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
        checkpoint = torch.load(params.checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state'])
        iters = checkpoint['iters']
        startEpoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if world_rank == 0:
        logging.info(model)
        logging.info("Starting Training Loop...")

    for epoch in range(startEpoch, startEpoch + params.num_epochs):
        start = time.time()
        nsteps = 0
        fw_time = 0.
        bw_time = 0.
        log_time = 0.
        model.train()
        step_time = time.time()
        for i, data in enumerate(train_data_loader, 0):
            iters += 1
            #adjust_LR(optimizer, params, iters)
            inp, tar = map(lambda x: x.to(device), data)
            if not args.io_only:
                # fw pass
                fw_time -= time.time()
                optimizer.zero_grad()
                #with amp.autocast():
                gen = model(inp)
                loss = UNet.loss_func(gen, tar, params)
                fw_time += time.time()

                # bw pass
                bw_time -= time.time()
                loss.backward()
                optimizer.step()
                #gscaler.scale(loss).backward()
                #gscaler.step(optimizer)
                #gscaler.update()
                bw_time += time.time()
            nsteps += 1

        # epoch done: average the per-step timings
        dist.barrier()
        step_time = (time.time() - step_time) / float(nsteps)
        fw_time /= float(nsteps)
        bw_time /= float(nsteps)
        io_time = max([step_time - fw_time - bw_time, 0])
        iters_per_sec = 1. / step_time

        ## Output training stats
        #model.eval()
        #if world_rank == 0:
        #    log_start = time.time()
        #    gens = []
        #    tars = []
        #    with torch.no_grad():
        #        for i, data in enumerate(train_data_loader, 0):
        #            if i >= 16:
        #                break
        #            #inp, tar = map(lambda x: x.to(device), data)
        #            inp, tar = data
        #            gen = model(inp)
        #            gens.append(gen.detach().cpu().numpy())
        #            tars.append(tar.detach().cpu().numpy())
        #    gens = np.concatenate(gens, axis=0)
        #    tars = np.concatenate(tars, axis=0)
        #
        #    # Scalars
        #    args.tboard_writer.add_scalar('G_loss', loss.item(), iters)
        #
        #    # Plots
        #    fig, chi, L1score = meanL1(gens, tars)
        #    args.tboard_writer.add_figure('pixhist', fig, iters, close=True)
        #    args.tboard_writer.add_scalar('Metrics/chi', chi, iters)
        #    args.tboard_writer.add_scalar('Metrics/rhoL1', L1score[0], iters)
        #    args.tboard_writer.add_scalar('Metrics/vxL1', L1score[1], iters)
        #    args.tboard_writer.add_scalar('Metrics/vyL1', L1score[2], iters)
        #    args.tboard_writer.add_scalar('Metrics/vzL1', L1score[3], iters)
        #    args.tboard_writer.add_scalar('Metrics/TL1', L1score[4], iters)
        #
        #    fig = generate_images(inp.detach().cpu().numpy()[0], gens[-1], tars[-1])
        #    args.tboard_writer.add_figure('genimg', fig, iters, close=True)
        #    log_end = time.time()
        #    log_time += log_end - log_start
        #
        #    # Save checkpoint
        #    torch.save({'iters': iters, 'epoch': epoch, 'model_state': model.state_dict(),
        #                'optimizer_state_dict': optimizer.state_dict()}, params.checkpoint_path)

        end = time.time()
        if world_rank == 0:
            logging.info('Time taken for epoch {} is {} sec'.format(
                epoch + 1, end - start))
            logging.info(
                'total time / step = {}, fw time / step = {}, bw time / step = {}, '
                'exposed io time / step = {}, iters/s = {}, logging time = {}'.format(
                    step_time, fw_time, bw_time, io_time, iters_per_sec, log_time))

    # finalize
    dist.barrier()
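# A minimal sketch of how the commented-out AMP hooks above (amp.autocast /
# amp.GradScaler) could be enabled, assuming PyTorch >= 1.6. The helper name
# `amp_train_step` is illustrative and not part of the project.
from torch.cuda import amp


def amp_train_step(model, optimizer, gscaler, inp, tar, params):
    """One mixed-precision step; returns the (unscaled) loss tensor."""
    optimizer.zero_grad()
    # Run the forward pass under autocast so eligible ops use float16.
    with amp.autocast():
        gen = model(inp)
        loss = UNet.loss_func(gen, tar, params)
    # Scale the loss to avoid float16 gradient underflow, then unscale,
    # step, and update the scale factor for the next iteration.
    gscaler.scale(loss).backward()
    gscaler.step(optimizer)
    gscaler.update()
    return loss
# The scaler would be created once per run (gscaler = amp.GradScaler()), and
# the fw/bw timers above would wrap the autocast block and the backward/step
# calls respectively.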
# A second variant of train() (apparently an earlier revision, with apex-style
# AMP comments and per-epoch TensorBoard logging). If both definitions live in
# one module, this one shadows the one above.
def train(params, args, world_rank):
    logging.info('rank %d, begin data loader init' % world_rank)
    train_data_loader = get_data_loader_distributed(params, world_rank)
    test_data_loader = get_data_loader_distributed_test(params, world_rank)
    logging.info('rank %d, data loader initialized' % world_rank)

    model = UNet.UNet(params).cuda()
    if not args.resuming:
        model.apply(model.get_weights_function(params.weight_init))

    optimizer = optimizers.FusedAdam(model.parameters(), lr=params.lr)
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1")  # for automatic mixed precision
    if params.distributed:
        model = DistributedDataParallel(model)

    iters = 0
    startEpoch = 0
    checkpoint = None
    if args.resuming:
        if world_rank == 0:
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
        checkpoint = torch.load(params.checkpoint_path,
                                map_location='cuda:{}'.format(args.local_rank))
        model.load_state_dict(checkpoint['model_state'])
        iters = checkpoint['iters']
        startEpoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if world_rank == 0:
        logging.info(model)
        logging.info("Starting Training Loop...")

    device = torch.cuda.current_device()
    for epoch in range(startEpoch, startEpoch + params.num_epochs):
        start = time.time()
        tr_time = 0.
        log_time = 0.
        for i, data in enumerate(train_data_loader, 0):
            iters += 1
            adjust_LR(optimizer, params, iters)
            inp, tar = map(lambda x: x.to(device), data)

            tr_start = time.time()
            b_size = inp.size(0)
            model.zero_grad()
            gen = model(inp)
            loss = UNet.loss_func(gen, tar, params)
            loss.backward()  # fixed precision
            # automatic mixed precision:
            #with amp.scale_loss(loss, optimizer) as scaled_loss:
            #    scaled_loss.backward()
            optimizer.step()
            tr_end = time.time()
            tr_time += tr_end - tr_start

        # Output training stats
        if world_rank == 0:
            log_start = time.time()
            gens = []
            tars = []
            with torch.no_grad():
                for i, data in enumerate(test_data_loader, 0):
                    if i >= 50:
                        break
                    inp, tar = map(lambda x: x.to(device), data)
                    gen = model(inp)
                    gens.append(gen.detach().cpu().numpy())
                    tars.append(tar.detach().cpu().numpy())
            gens = np.concatenate(gens, axis=0)
            tars = np.concatenate(tars, axis=0)

            # Scalars
            args.tboard_writer.add_scalar('G_loss', loss.item(), iters)

            # Plots
            fig = plot_gens_tars(gens, tars)
            #fig, chi, L1score = meanL1(gens, tars)
            #args.tboard_writer.add_figure('pixhist', fig, iters, close=True)
            #args.tboard_writer.add_scalar('Metrics/chi', chi, iters)
            #args.tboard_writer.add_scalar('Metrics/rhoL1', L1score[0], iters)
            #args.tboard_writer.add_scalar('Metrics/vxL1', L1score[1], iters)
            #args.tboard_writer.add_scalar('Metrics/vyL1', L1score[2], iters)
            #args.tboard_writer.add_scalar('Metrics/vzL1', L1score[3], iters)
            #args.tboard_writer.add_scalar('Metrics/TL1', L1score[4], iters)
            #
            #fig = generate_images(inp.detach().cpu().numpy()[0], gens[-1], tars[-1])
            for figiter in range(5):
                figtag = 'test' + str(figiter)
                args.tboard_writer.add_figure(tag=figtag, figure=fig[figiter],
                                              close=True)
            log_end = time.time()
            log_time += log_end - log_start

            # Save checkpoint
            torch.save({'iters': iters, 'epoch': epoch,
                        'model_state': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict()},
                       params.checkpoint_path)

        end = time.time()
        if world_rank == 0:
            logging.info('Time taken for epoch {} is {} sec'.format(
                epoch + 1, end - start))
            logging.info('train step time={}, logging time={}'.format(
                tr_time, log_time))
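# A minimal launch sketch for the distributed path, assuming a torchrun-style
# launcher that sets RANK/WORLD_SIZE/LOCAL_RANK and the MASTER_* variables,
# and assuming the first (4-argument) train() variant is the one kept in the
# module. `build_params_and_args` is a hypothetical stand-in for the project's
# config/argparse setup, which is not shown here.
import os


def build_params_and_args():
    raise NotImplementedError("project-specific config + argparse setup")


def main():
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')  # rank/world size read from env
    world_rank = dist.get_rank()
    params, args = build_params_and_args()  # hypothetical helper
    train(params, args, world_rank, local_rank)
    dist.destroy_process_group()


if __name__ == '__main__':
    main()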