def train(exp=None):
    """Main function to run the training."""
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1]).cuda()
    net = ED(encoder, decoder)
    run_dir = "./runs/" + TIMESTAMP
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    # tb = SummaryWriter(run_dir)

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=20, verbose=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    if os.path.exists(os.path.join(save_dir, "checkpoint.pth.tar")):
        # load existing model
        print("==> loading existing model")
        model_info = torch.load(os.path.join(save_dir, "checkpoint.pth.tar"))
        net.load_state_dict(model_info["state_dict"])
        optimizer = torch.optim.Adam(net.parameters())
        optimizer.load_state_dict(model_info["optimizer"])
        cur_epoch = model_info["epoch"] + 1
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0

    lossfunction = nn.MSELoss().cuda()
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)

    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    # mini_val_loss = np.inf

    for epoch in range(cur_epoch, args.epochs + 1):
        if exp is not None:
            exp.log_metric("epoch", epoch)
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            inputs = inputVar.to(device)  # B,S,C,H,W
            label = targetVar.to(device)  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / args.batch_size
            train_losses.append(loss_aver)
            loss.backward()
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({
                "trainloss": "{:.6f}".format(loss_aver),
                "epoch": "{:02d}".format(epoch),
            })
        # tb.add_scalar('TrainLoss', loss_aver, epoch)

        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
                if i == 3000:
                    break
                inputs = inputVar.to(device)
                label = targetVar.to(device)
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / args.batch_size
                # record validation loss
                valid_losses.append(loss_aver)
                # print("validloss: {:.6f}, epoch: {:02d}".format(loss_aver, epoch), end='\r', flush=True)
                t.set_postfix({
                    "validloss": "{:.6f}".format(loss_aver),
                    "epoch": "{:02d}".format(epoch),
                })
            # tb.add_scalar('ValidLoss', loss_aver, epoch)
        torch.cuda.empty_cache()

        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        epoch_len = len(str(args.epochs))
        print_msg = (f"[{epoch:>{epoch_len}}/{args.epochs:>{epoch_len}}] " +
                     f"train_loss: {train_loss:.6f} " +
                     f"valid_loss: {valid_loss:.6f}")
        # print(print_msg)

        # clear lists to track next epoch
        if exp is not None:
            exp.log_metric("TrainLoss", train_loss)
            exp.log_metric("ValidLoss", valid_loss)
        train_losses = []
        valid_losses = []

        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            "epoch": epoch,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        if epoch % args.save_every == 0 and epoch != 0:
            torch.save(
                model_dict,
                save_dir + "/" +
                "checkpoint_{}_{:.6f}.pth.tar".format(epoch, valid_loss.item()),
            )
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    with open("avg_train_losses.txt", "wt") as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", "wt") as f:
        for i in avg_valid_losses:
            print(i, file=f)
def test():
    """Main function to run the testing."""
    testFolder = MovingMNIST(is_train=False,
                             root='../data/npy-064/',
                             mode='test',
                             n_frames_input=args.frames_input,
                             n_frames_output=args.frames_output,
                             num_objects=[3])
    testLoader = torch.utils.data.DataLoader(testFolder,
                                             batch_size=args.batch_size,
                                             shuffle=False)

    if args.convlstm:
        encoder_params = convlstm_encoder_params
        decoder_params = convlstm_decoder_params
    elif args.convgru:
        encoder_params = convgru_encoder_params
        decoder_params = convgru_decoder_params
    else:
        encoder_params = convgru_encoder_params
        decoder_params = convgru_decoder_params

    # TIMESTAMP = args.timestamp
    # restore args
    CHECKPOINT = args.checkpoint
    TIMESTAMP = args.timestamp
    save_dir = './save_model/' + TIMESTAMP
    args_path = os.path.join(save_dir, 'cmd_args.txt')
    if os.path.exists(args_path):
        with open(args_path, 'r') as f:
            args.__dict__ = json.load(f)
    args.is_train = False

    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1],
                      args.frames_output).cuda()
    net = ED(encoder, decoder)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    if os.path.exists(save_dir):
        # load existing model
        print('==> loading existing model')
        model_info = torch.load(CHECKPOINT)
        net.load_state_dict(model_info['state_dict'])
        optimizer = torch.optim.Adam(net.parameters())
        optimizer.load_state_dict(model_info['optimizer'])
    else:
        print('there is no such checkpoint in', save_dir)
        exit()

    lossfunction = nn.MSELoss().cuda()
    # to track the test loss as the model is evaluated
    test_losses = []
    # to track the average test loss per epoch
    avg_test_losses = []
    # mini_val_loss = np.inf
    preds = []

    ######################
    #   test the model   #
    ######################
    with torch.no_grad():
        net.eval()
        t = tqdm(testLoader, leave=False, total=len(testLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            if i == 3000:
                break
            inputs = inputVar.to(device)
            # label = targetVar.to(device)
            pred = net(inputs)
            # loss = lossfunction(pred, label)
            preds.append(pred)
            # loss_aver = loss.item() / args.batch_size
            # record test loss
            # test_losses.append(loss_aver)
    torch.cuda.empty_cache()

    # print test statistics
    # calculate average loss over an epoch
    # test_loss = np.average(test_losses)
    # avg_test_losses.append(test_loss)
    # print_msg = (f'test_loss: {test_loss:.6f}')
    # print(print_msg)

    import pickle
    with open("preds.pkl", "wb") as fp:
        pickle.dump(preds, fp)
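# ---------------------------------------------------------------------------
# test() above only pickles the raw prediction tensors, so a small follow-up
# step is needed to inspect them.  The helper below is a minimal sketch; it
# assumes preds.pkl was written by the loop above and that the CUDA tensors can
# be deserialized (i.e. a GPU is available), otherwise move the tensors to the
# CPU before pickling.
import pickle

import torch


def load_predictions(path="preds.pkl"):
    """Load the pickled list of B,S,C,H,W batches and stack them on the CPU."""
    with open(path, "rb") as fp:
        batches = pickle.load(fp)
    return torch.cat([p.cpu() for p in batches], dim=0)  # N,S,C,H,W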
def train():
    """Main function to run the training."""
    # instantiate the Encoder and Decoder
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1]).cuda()
    # instantiate the encoder-decoder (ED) model
    net = ED(encoder, decoder)
    # run directory
    run_dir = './runs/' + TIMESTAMP
    # create run_dir if it does not exist
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    tb = SummaryWriter(run_dir)

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=20, verbose=True)
    # check whether CUDA is available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # use data parallelism if more than one GPU is available
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    # check whether checkpoint.pth.tar exists
    if os.path.exists(os.path.join(save_dir, 'checkpoint.pth.tar')):
        # load the saved model
        print('==> loading existing model')
        model_info = torch.load(os.path.join(save_dir, 'checkpoint.pth.tar'))
        net.load_state_dict(model_info['state_dict'])
        optimizer = torch.optim.Adam(net.parameters())
        optimizer.load_state_dict(model_info['optimizer'])
        cur_epoch = model_info['epoch'] + 1
    else:
        # no checkpoint: create save_dir if necessary
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        # initialize the current epoch to 0
        cur_epoch = 0

    # use MSELoss as the loss function
    lossfunction = nn.MSELoss().cuda()
    # use Adam as the optimizer
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)

    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    # mini_val_loss = np.inf

    for epoch in range(cur_epoch, args.epochs + 1):
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            inputs = inputVar.to(device)  # B,S,C,H,W
            label = targetVar.to(device)  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / args.batch_size
            train_losses.append(loss_aver)
            loss.backward()
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({
                'trainloss': '{:.6f}'.format(loss_aver),
                'epoch': '{:02d}'.format(epoch)
            })
        tb.add_scalar('TrainLoss', loss_aver, epoch)

        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
                if i == 3000:
                    break
                inputs = inputVar.to(device)
                label = targetVar.to(device)
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / args.batch_size
                # record validation loss
                valid_losses.append(loss_aver)
                # print("validloss: {:.6f}, epoch: {:02d}".format(loss_aver, epoch), end='\r', flush=True)
                t.set_postfix({
                    'validloss': '{:.6f}'.format(loss_aver),
                    'epoch': '{:02d}'.format(epoch)
                })
            tb.add_scalar('ValidLoss', loss_aver, epoch)
        torch.cuda.empty_cache()

        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        epoch_len = len(str(args.epochs))
        print_msg = (f'[{epoch:>{epoch_len}}/{args.epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.6f} ' +
                     f'valid_loss: {valid_loss:.6f}')
        print(print_msg)

        # clear lists to track next epoch
        train_losses = []
        valid_losses = []

        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    with open("avg_train_losses.txt", 'wt') as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", 'wt') as f:
        for i in avg_valid_losses:
            print(i, file=f)
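# ---------------------------------------------------------------------------
# Every train() variant above resumes from a model_dict checkpoint holding the
# epoch, model state, and optimizer state.  The function below is a
# self-contained sketch of that save/resume round trip; the tiny nn.Linear
# stands in for ED(encoder, decoder) and the directory name is hypothetical.
import os

import torch
import torch.nn as nn


def _checkpoint_roundtrip_demo():
    demo_dir = "./save_model/demo"
    os.makedirs(demo_dir, exist_ok=True)
    demo_ckpt = os.path.join(demo_dir, "checkpoint.pth.tar")

    demo_net = nn.Linear(4, 4)  # stand-in for ED(encoder, decoder)
    demo_opt = torch.optim.Adam(demo_net.parameters(), lr=1e-3)

    # save, as done at the end of an epoch
    torch.save({"epoch": 0,
                "state_dict": demo_net.state_dict(),
                "optimizer": demo_opt.state_dict()}, demo_ckpt)

    # resume, as done at the start of train()
    model_info = torch.load(demo_ckpt)
    demo_net.load_state_dict(model_info["state_dict"])
    demo_opt.load_state_dict(model_info["optimizer"])
    return model_info["epoch"] + 1  # continue from the next epoch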
def test():
    """Main function to run the testing."""
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1]).cuda()
    net = ED(encoder, decoder)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    # load the model to be tested
    if os.path.exists(args.model_path):
        # load existing model
        print('==> loading existing model ' + args.model_path)
        model_info = torch.load(args.model_path)
        net.load_state_dict(model_info['state_dict'])
        model_dir = args.model_path.split('/')[-2]
    else:
        raise Exception("Invalid model path!")

    # create the directory that stores the visualization images
    if not os.path.isdir(args.vis_dir):
        os.makedirs(args.vis_dir)

    class_weights = torch.FloatTensor([1.0, 15.0]).cuda()
    lossfunction = nn.CrossEntropyLoss(weight=class_weights).cuda()

    # to track the test loss as the model is evaluated
    test_losses = []
    # to track the average test loss per epoch
    avg_test_losses = []

    ######################
    #   test the model   #
    ######################
    with torch.no_grad():
        # set the module to eval mode (only affects dropout and batch norm)
        net.eval()
        # tqdm progress bar
        t = tqdm(testLoader, total=len(testLoader))
        for i, (seq_len, scan_seq, label_seq, mask_seq, label_id) in enumerate(t):
            # the sequence length is not fixed: at least the first 2 frames are
            # used as input, and the last 3 frames are always predicted
            inputs = torch.cat((scan_seq, mask_seq.float()),
                               dim=2).to(device)[:, :-3, ...]  # B,S,C,H,W
            label = mask_seq.to(device)[:, (seq_len - 3):, ...]  # B,S,C,H,W
            pred = net(inputs)
            SaveVis(model_dir, i, scan_seq.to(device), mask_seq.to(device), pred)

            seq_number, batch_size, input_channel, height, width = pred.size()
            pred = pred.reshape(-1, input_channel, height, width)  # reshape to B*S,C,H,W
            seq_number, batch_size, input_channel, height, width = label.size()
            label = label.reshape(-1, height, width)  # reshape to B*S,H,W
            label = label.to(device=device, dtype=torch.long)

            loss = lossfunction(pred, label)
            loss_aver = loss.item() / (label.shape[0])
            # record test loss
            test_losses.append(loss_aver)
            t.set_postfix({
                'test_loss': '{:.6f}'.format(loss_aver),
                'cnt': '{:02d}'.format(i)
            })
            # the number of samples to test is limited by the command-line argument
            if i >= args.sample and args.sample > 0:
                break

    torch.cuda.empty_cache()
    # print test statistics
    # calculate average loss over an epoch
    test_loss = np.average(test_losses)
    avg_test_losses.append(test_loss)
    # epoch_len = len(str(args.epochs))
    test_losses = []
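# ---------------------------------------------------------------------------
# Shape sketch for the input/label construction used in the lidar segmentation
# code above: range images and masks are concatenated along the channel axis,
# the last 3 frames are dropped from the input, and only the last 3 masks are
# kept as the label.  The sizes below are hypothetical.
import torch


def _input_label_shape_demo():
    B, S, H, W = 1, 8, 16, 32
    scan_seq = torch.randn(B, S, 1, H, W)              # range images
    mask_seq = torch.randint(0, 2, (B, S, 1, H, W))    # binary moving-object masks
    seq_len = S

    inputs = torch.cat((scan_seq, mask_seq.float()), dim=2)[:, :-3, ...]  # B, S-3, 2, H, W
    label = mask_seq[:, (seq_len - 3):, ...]                              # B, 3, 1, H, W
    return inputs.shape, label.shape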
def test():
    """Main function to run the testing."""
    # TIMESTAMP = args.timestamp
    # restore args
    CHECKPOINT = args.checkpoint
    TIMESTAMP = args.timestamp
    save_dir = './save_model/' + TIMESTAMP
    args_path = os.path.join(save_dir, 'cmd_args.txt')
    if os.path.exists(args_path):
        with open(args_path, 'r') as f:
            args.__dict__ = json.load(f)
    args.is_train = False

    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1],
                      args.frames_output).cuda()
    net = ED(encoder, decoder)
    # run_dir = './runs/' + TIMESTAMP
    # if not os.path.isdir(run_dir):
    #     os.makedirs(run_dir)
    # tb = SummaryWriter(run_dir)

    # initialize the early_stopping object
    # early_stopping = EarlyStopping(patience=20, verbose=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    if os.path.exists(save_dir):
        # load existing model
        print('==> loading existing model')
        model_info = torch.load(CHECKPOINT)
        net.load_state_dict(model_info['state_dict'])
        optimizer = torch.optim.Adam(net.parameters())
        optimizer.load_state_dict(model_info['optimizer'])
    else:
        print('there is no such checkpoint')
        exit()

    lossfunction = nn.MSELoss().cuda()
    # optimizer = optim.Adam(net.parameters(), lr=args.lr)
    # pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5,
    #                                                   patience=4, verbose=True)

    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    # mini_val_loss = np.inf
    preds = []

    ######################
    # validate the model #
    ######################
    with torch.no_grad():
        net.eval()
        t = tqdm(validLoader, leave=False, total=len(validLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            if i == 3000:
                break
            inputs = inputVar.to(device)
            label = targetVar.to(device)
            pred = net(inputs)
            loss = lossfunction(pred, label)
            preds.append(pred)
            loss_aver = loss.item() / args.batch_size
            # record validation loss
            valid_losses.append(loss_aver)
    torch.cuda.empty_cache()

    # print validation statistics
    # calculate average loss over an epoch
    valid_loss = np.average(valid_losses)
    # avg_valid_losses.append(valid_loss)
    print_msg = f'valid_loss: {valid_loss:.6f}'
    print(print_msg)

    import pickle
    with open("preds.pkl", "wb") as fp:
        pickle.dump(preds, fp)
def train():
    """Main function to run the training."""
    restore = False
    # TIMESTAMP = "2020-03-09T00-00-00"
    if args.timestamp == "NA":
        TIMESTAMP = datetime.now().strftime("%b%d-%H%M%S")
        print('TIMESTAMP', TIMESTAMP)
    else:
        # restore
        restore = True
        TIMESTAMP = args.timestamp
    save_dir = './save_model/' + TIMESTAMP

    if restore:
        # restore args
        with open(os.path.join(save_dir, 'cmd_args.txt'), 'r') as f:
            args.__dict__ = json.load(f)

    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1],
                      args.frames_output).cuda()
    net = ED(encoder, decoder)
    run_dir = './runs/' + TIMESTAMP
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    tb = SummaryWriter(run_dir)

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=30, verbose=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    if os.path.exists(os.path.join(save_dir, 'checkpoint.pth.tar')):
        # load existing model
        print('==> loading existing model')
        model_info = torch.load(os.path.join(save_dir, 'checkpoint.pth.tar'))
        net.load_state_dict(model_info['state_dict'])
        optimizer = torch.optim.Adam(net.parameters())
        optimizer.load_state_dict(model_info['optimizer'])
        cur_epoch = model_info['epoch'] + 1
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0

    lossfunction = nn.MSELoss().cuda()
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)

    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    # mini_val_loss = np.inf

    for epoch in range(cur_epoch, args.epochs + 1):
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            inputs = inputVar.to(device)  # B,S,C,H,W
            label = targetVar.to(device)  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / args.batch_size
            train_losses.append(loss_aver)
            loss.backward()
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({
                'trainloss': '{:.6f}'.format(loss_aver),
                'epoch': '{:02d}'.format(epoch)
            })
        tb.add_scalar('TrainLoss', loss_aver, epoch)

        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
                if i == 3000:
                    break
                inputs = inputVar.to(device)
                label = targetVar.to(device)
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / args.batch_size
                # record validation loss
                valid_losses.append(loss_aver)
                # print("validloss: {:.6f}, epoch: {:02d}".format(loss_aver, epoch), end='\r', flush=True)
                t.set_postfix({
                    'validloss': '{:.6f}'.format(loss_aver),
                    'epoch': '{:02d}'.format(epoch)
                })
            tb.add_scalar('ValidLoss', loss_aver, epoch)
        torch.cuda.empty_cache()

        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        epoch_len = len(str(args.epochs))
        print_msg = (f'[{epoch:>{epoch_len}}/{args.epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.6f} ' +
                     f'valid_loss: {valid_loss:.6f}')
        print(print_msg)

        # clear lists to track next epoch
        train_losses = []
        valid_losses = []

        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    with open("avg_train_losses.txt", 'wt') as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", 'wt') as f:
        for i in avg_valid_losses:
            print(i, file=f)

    # save args
    if not restore:
        with open(os.path.join(save_dir, 'cmd_args.txt'), 'w+') as f:
            json.dump(args.__dict__, f, indent=2)
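# ---------------------------------------------------------------------------
# Sketch of the cmd_args.txt round trip used above: on a fresh run the argparse
# namespace is dumped as JSON, and on restore (--timestamp pointing at an
# existing run) it is loaded back into args.__dict__.  The directory and the
# argument values below are hypothetical.
import argparse
import json
import os


def _cmd_args_roundtrip_demo():
    demo_dir = "./save_model/demo"
    os.makedirs(demo_dir, exist_ok=True)
    demo_args = argparse.Namespace(lr=1e-4, epochs=500,
                                   frames_input=10, frames_output=10)

    # save on the first run
    with open(os.path.join(demo_dir, "cmd_args.txt"), "w+") as f:
        json.dump(demo_args.__dict__, f, indent=2)

    # restore on a later run
    with open(os.path.join(demo_dir, "cmd_args.txt"), "r") as f:
        demo_args.__dict__ = json.load(f)
    return demo_args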
def train(exp=None):
    """Main function to run the training."""
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1]).cuda()
    net = ED(encoder, decoder)
    run_dir = "./runs/" + TIMESTAMP
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    # tb = SummaryWriter(run_dir)

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=20, verbose=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    if os.path.exists(args.checkpoint) and args.continue_train:
        # load existing model
        print("==> loading existing model")
        model_info = torch.load(args.checkpoint)
        net.load_state_dict(model_info["state_dict"])
        optimizer = torch.optim.Adam(net.parameters())
        optimizer.load_state_dict(model_info["optimizer"])
        cur_epoch = model_info["epoch"] + 1
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0

    lossfunction = nn.MSELoss().cuda()
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)

    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    # PSNR / SSIM per output frame
    avg_psnrs = {}
    avg_ssims = {}
    for j in range(args.frames_output):
        avg_psnrs[j] = []
        avg_ssims[j] = []

    if args.checkdata:
        # Checking dataloader
        print("Checking Dataloader!")
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            assert targetVar.shape == torch.Size([
                args.batchsize, args.frames_output, 1, args.data_h, args.data_w
            ])
            assert inputVar.shape == torch.Size([
                args.batchsize, args.frames_input, 1, args.data_h, args.data_w
            ])
        print("TrainLoader checking is complete!")
        t = tqdm(validLoader, leave=False, total=len(validLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            assert targetVar.shape == torch.Size([
                args.batchsize, args.frames_output, 1, args.data_h, args.data_w
            ])
            assert inputVar.shape == torch.Size([
                args.batchsize, args.frames_input, 1, args.data_h, args.data_w
            ])
        print("ValidLoader checking is complete!")

    # mini_val_loss = np.inf
    for epoch in range(cur_epoch, args.epochs + 1):
        # to track the training loss as the model trains
        train_losses = []
        # to track the validation loss as the model trains
        valid_losses = []
        psnr_dict = {}
        ssim_dict = {}
        for j in range(args.frames_output):
            psnr_dict[j] = 0
            ssim_dict[j] = 0
        image_log = []

        if exp is not None:
            exp.log_metric("epoch", epoch)
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
            inputs = inputVar.to(device)  # B,S,C,H,W
            label = targetVar.to(device)  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / args.batchsize
            train_losses.append(loss_aver)
            loss.backward()
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({
                "trainloss": "{:.6f}".format(loss_aver),
                "epoch": "{:02d}".format(epoch),
            })
        # tb.add_scalar('TrainLoss', loss_aver, epoch)

        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (idx, targetVar, inputVar, _, _) in enumerate(t):
                inputs = inputVar.to(device)
                label = targetVar.to(device)
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / args.batchsize
                # record validation loss
                valid_losses.append(loss_aver)
                for j in range(args.frames_output):
                    psnr_dict[j] += psnr(pred[:, j], label[:, j])
                    ssim_dict[j] += ssim(pred[:, j], label[:, j])
                # print("validloss: {:.6f}, epoch: {:02d}".format(loss_aver, epoch), end='\r', flush=True)
                t.set_postfix({
                    "validloss": "{:.6f}".format(loss_aver),
                    "epoch": "{:02d}".format(epoch),
                })
                if i % 500 == 499:
                    for k in range(args.frames_output):
                        image_log.append(label[0, k].unsqueeze(0).repeat(1, 3, 1, 1))
                        image_log.append(pred[0, k].unsqueeze(0).repeat(1, 3, 1, 1))
                    upload_images(
                        image_log,
                        epoch,
                        exp=exp,
                        im_per_row=2,
                        rows_per_log=int(len(image_log) / 2),
                    )
            # tb.add_scalar('ValidLoss', loss_aver, epoch)
        torch.cuda.empty_cache()

        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        for j in range(args.frames_output):
            avg_psnrs[j].append(psnr_dict[j] / i)
            avg_ssims[j].append(ssim_dict[j] / i)

        epoch_len = len(str(args.epochs))
        print_msg = (f"[{epoch:>{epoch_len}}/{args.epochs:>{epoch_len}}] " +
                     f"train_loss: {train_loss:.6f} " +
                     f"valid_loss: {valid_loss:.6f} " +
                     f"PSNR_1: {psnr_dict[0] / i:.6f} " +
                     f"SSIM_1: {ssim_dict[0] / i:.6f}")
        # print(print_msg)

        # clear lists to track next epoch
        if exp is not None:
            exp.log_metric("TrainLoss", train_loss)
            exp.log_metric("ValidLoss", valid_loss)
            exp.log_metric("PSNR_1", psnr_dict[0] / i)
            exp.log_metric("SSIM_1", ssim_dict[0] / i)

        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            "epoch": epoch,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict(),
            "avg_psnrs": avg_psnrs,
            "avg_ssims": avg_ssims,
            "avg_valid_losses": avg_valid_losses,
            "avg_train_losses": avg_train_losses,
        }

        save_flag = False
        if epoch % args.save_every == 0:
            torch.save(
                model_dict,
                save_dir + "/" +
                "checkpoint_{}_{:.6f}.pth".format(epoch, valid_loss.item()),
            )
            print("Saved " +
                  "checkpoint_{}_{:.6f}.pth".format(epoch, valid_loss.item()))
            save_flag = True
        if avg_psnrs[0][-1] == max(avg_psnrs[0]) and not save_flag:
            torch.save(model_dict, save_dir + "/" + "bestpsnr_1.pth")
            print("Best psnr found and saved")
            save_flag = True
        if avg_ssims[0][-1] == max(avg_ssims[0]) and not save_flag:
            torch.save(model_dict, save_dir + "/" + "bestssim_1.pth")
            print("Best ssim found and saved")
            save_flag = True
        if avg_valid_losses[-1] == min(avg_valid_losses) and not save_flag:
            torch.save(model_dict, save_dir + "/" + "bestvalidloss.pth")
            print("Best validloss found and saved")
            save_flag = True
        if not save_flag:
            torch.save(model_dict, save_dir + "/" + "checkpoint.pth")
            print("The latest normal checkpoint saved")

        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    with open("avg_train_losses.txt", "wt") as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", "wt") as f:
        for i in avg_valid_losses:
            print(i, file=f)
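# ---------------------------------------------------------------------------
# The loop above accumulates psnr()/ssim() per output frame, but the metric
# implementations are not shown here.  Below is a minimal, hypothetical PSNR
# helper matching how psnr(pred[:, j], label[:, j]) is called, assuming frames
# are scaled to [0, 1]; the repository's own metric code may differ.
import torch


def psnr_sketch(pred, target, max_val=1.0, eps=1e-10):
    """Mean peak signal-to-noise ratio over a batch of frames shaped B,C,H,W."""
    mse = torch.mean((pred - target) ** 2, dim=(1, 2, 3))
    return torch.mean(10.0 * torch.log10(max_val ** 2 / (mse + eps))).item()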
def train():
    """Main function to run the training."""
    encoder = Encoder(encoder_params[0], encoder_params[1]).cuda()
    decoder = Decoder(decoder_params[0], decoder_params[1]).cuda()
    net = ED(encoder, decoder)
    run_dir = './runs/' + TIMESTAMP + args.mname
    if not os.path.isdir(run_dir):
        os.makedirs(run_dir)
    tb = SummaryWriter(run_dir)

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=200, verbose=True)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(device)

    if os.path.exists(os.path.join(save_dir, 'checkpoint.pth.tar')):
        # load existing model
        print('==> loading existing model')
        model_info = torch.load(os.path.join(save_dir, 'checkpoint.pth.tar'))
        net.load_state_dict(model_info['state_dict'])
        optimizer = torch.optim.Adam(net.parameters())
        optimizer.load_state_dict(model_info['optimizer'])
        cur_epoch = model_info['epoch'] + 1
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0

    class_weights = torch.FloatTensor([1.0, 15.0]).cuda()
    lossfunction = nn.CrossEntropyLoss(weight=class_weights).cuda()
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=5,
                                                      verbose=True)

    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    min_train_loss = np.inf

    for epoch in range(cur_epoch, args.epochs + 1):
        print(time.strftime("now time: %Y%m%d_%H:%M", time.localtime(time.time())))
        ###################
        # train the model #
        ###################
        # tqdm progress bar
        t = tqdm(trainLoader, total=len(trainLoader))
        for i, (seq_len, scan_seq, _, mask_seq, _) in enumerate(t):
            # the sequence length is not fixed: at least the first 2 frames are
            # used as input, and the last 3 frames are always predicted
            inputs = torch.cat((scan_seq, mask_seq.float()),
                               dim=2).to(device)[:, :-3, ...]  # B,S,C,H,W
            label = mask_seq.to(device)[:, (seq_len - 3):, ...]  # B,S,C,H,W
            optimizer.zero_grad()
            # set the module to training mode (only affects dropout and batch norm)
            net.train()
            pred = net(inputs)  # B,S,C,H,W

            # draw visualization results in TensorBoard
            if i % 100 == 0:
                grid_ri_lab, grid_pred = get_visualization_example(
                    scan_seq.to(device), mask_seq.to(device), pred, device)
                tb.add_image('visualization/train/rangeImage_gtMask',
                             grid_ri_lab, global_step=epoch)
                tb.add_image('visualization/train/prediction',
                             grid_pred, global_step=epoch)

            seq_number, batch_size, input_channel, height, width = pred.size()
            pred = pred.reshape(-1, input_channel, height, width)  # reshape to B*S,C,H,W
            seq_number, batch_size, input_channel, height, width = label.size()
            label = label.reshape(-1, height, width)  # reshape to B*S,H,W
            label = label.to(device=device, dtype=torch.long)

            # compute the loss
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / (label.shape[0] * batch_size)
            train_losses.append(loss_aver)
            loss.backward()
            # clip gradients to [-clip_value, clip_value] to prevent gradient explosion
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=30.0)
            optimizer.step()
            t.set_postfix({
                'trainloss': '{:.6f}'.format(loss_aver),
                'epoch': '{:02d}'.format(epoch)
            })
        tb.add_scalar('TrainLoss', np.average(train_losses), epoch)

        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            # set the module to eval mode (only affects dropout and batch norm)
            net.eval()
            # tqdm progress bar
            t = tqdm(validLoader, total=len(validLoader))
            for i, (seq_len, scan_seq, _, mask_seq, _) in enumerate(t):
                if i == 300:
                    # limit the number of validation batches
                    break
                # the sequence length is not fixed: at least the first 2 frames are
                # used as input, and the last 3 frames are always predicted
                inputs = torch.cat((scan_seq, mask_seq.float()),
                                   dim=2).to(device)  # B,S,C,H,W
                label = mask_seq.to(device)[:, (seq_len - 3):, ...]  # B,S,C,H,W
                pred = net(inputs)

                # draw visualization results in TensorBoard
                if i % 100 == 0:
                    grid_ri_lab, grid_pred = get_visualization_example(
                        scan_seq.to(device), mask_seq.to(device), pred, device)
                    tb.add_image('visualization/valid/rangeImage_gtMask',
                                 grid_ri_lab, global_step=epoch)
                    tb.add_image('visualization/valid/prediction',
                                 grid_pred, global_step=epoch)

                seq_number, batch_size, input_channel, height, width = pred.size()
                pred = pred.reshape(-1, input_channel, height, width)  # reshape to B*S,C,H,W
                seq_number, batch_size, input_channel, height, width = label.size()
                label = label.reshape(-1, height, width)  # reshape to B*S,H,W
                label = label.to(device=device, dtype=torch.long)

                loss = lossfunction(pred, label)
                loss_aver = loss.item() / (label.shape[0] * batch_size)
                # record validation loss
                valid_losses.append(loss_aver)
                # print("validloss: {:.6f}, epoch: {:02d}".format(loss_aver, epoch), end='\r', flush=True)
                t.set_postfix({
                    'validloss': '{:.6f}'.format(loss_aver),
                    'epoch': '{:02d}'.format(epoch)
                })
                # get_visualization_example(inputs, label, pred)
            tb.add_scalar('ValidLoss', np.average(valid_losses), epoch)
        torch.cuda.empty_cache()

        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        epoch_len = len(str(args.epochs))
        print_msg = (f'[{epoch:>{epoch_len}}/{args.epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.6f} ' +
                     f'valid_loss: {valid_loss:.6f}')
        print(print_msg)

        # clear lists to track next epoch
        train_losses = []
        valid_losses = []

        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        # save the model with the lowest training loss
        if train_loss < min_train_loss:
            torch.save(model_dict, save_dir + "/" + "best_train_checkpoint.pth.tar")
            min_train_loss = train_loss
        # the model with the lowest validation loss is saved by early stopping
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break
    # end for

    with open("avg_train_losses.txt", 'wt') as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", 'wt') as f:
        for i in avg_valid_losses:
            print(i, file=f)
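# ---------------------------------------------------------------------------
# Sketch of the reshape + class-weighted CrossEntropyLoss step shared by the
# segmentation train()/test() above: the 5-D prediction is flattened so the
# loss sees (N, C, H, W) logits and (N, H, W) integer labels, with the moving
# class up-weighted by 15.  The random tensors and sizes below are hypothetical.
import torch
import torch.nn as nn


def _weighted_ce_demo():
    S, B, C, H, W = 3, 2, 2, 8, 16
    pred = torch.randn(S, B, C, H, W)                 # raw network logits
    label = torch.randint(0, C, (S, B, 1, H, W))      # ground-truth masks

    pred = pred.reshape(-1, C, H, W)                  # -> S*B, C, H, W
    label = label.reshape(-1, H, W).long()            # -> S*B, H, W

    class_weights = torch.FloatTensor([1.0, 15.0])    # background vs. moving object
    lossfunction = nn.CrossEntropyLoss(weight=class_weights)
    return lossfunction(pred, label)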