def _scale_for_display(dmap):
    """Rescale a density map so its maximum becomes 255 for image logging.

    A map whose maximum is not positive (for example an all-zero map) is
    returned unchanged, because dividing by that maximum would produce
    NaN/inf or sign-flipped pixel values.
    """
    peak = dmap.max()
    if peak > 0:
        return dmap / peak * 255
    return dmap


def main(args):
    """Train CANNet, validate after every epoch and log everything to wandb.

    Args:
        args: Run configuration passed to ``wandb.init``. The attributes read
            here are ``seed``, ``device``, ``lr``, ``decay``, ``batch_size``,
            ``epochs``, ``gradient_accumulation_steps``, ``train_image_root``,
            ``train_dmap_root``, ``val_image_root`` and ``val_dmap_root``.

    Side effects:
        Writes ``./checkpoints/epoch_<n>.pth`` after every epoch, logs
        loss/MAE scalars plus one preview sample per epoch to wandb, and
        prints the best validation MAE and a timestamp when training ends.
    """
    import time

    wandb.init(project="crowd", config=args)
    args = wandb.config

    # torch.cuda.manual_seed alone leaves the CPU generator unseeded, which
    # is the one used for DataLoader shuffling and CPU-side weight init.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    model = CANNet().to(args.device)
    criterion = nn.MSELoss(reduction='sum').to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                 weight_decay=args.decay)

    train_dataset = CrowdDataset(args.train_image_root, args.train_dmap_root,
                                 gt_downsample=8, phase='train')
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True)
    val_dataset = CrowdDataset(args.val_image_root, args.val_dmap_root,
                               gt_downsample=8, phase='test')
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False)

    os.makedirs('./checkpoints', exist_ok=True)

    accum_steps = args.gradient_accumulation_steps
    min_mae = 10000
    min_epoch = 0
    for epoch in tqdm(range(0, args.epochs)):
        # ---- training phase ----
        model.train()
        model.zero_grad()
        train_loss = 0
        train_mae = 0
        # True while gradients have been accumulated but not yet applied.
        pending_grads = False
        train_bar = tqdm(train_loader)
        for i, (img, gt_dmap) in enumerate(train_bar):
            img = img.to(args.device)
            gt_dmap = gt_dmap.to(args.device)
            # forward propagation
            et_dmap = model(img)
            # calculate loss
            loss = criterion(et_dmap, gt_dmap)
            train_loss += loss.item()
            train_mae += abs(et_dmap.detach().sum() - gt_dmap.sum()).item()
            # Scale so the accumulated gradient matches one large batch.
            (loss / accum_steps).backward()
            pending_grads = True
            if (i + 1) % accum_steps == 0:
                optimizer.step()
                model.zero_grad()
                pending_grads = False
            train_bar.set_postfix(loss=train_loss / (i + 1),
                                  mae=train_mae / (i + 1))
        # Flush a trailing partial accumulation group. Stepping when nothing
        # is pending would still move the weights through Adam's momentum
        # and weight decay, so only step if gradients are waiting.
        if pending_grads:
            optimizer.step()
            model.zero_grad()

        torch.save(model.state_dict(),
                   './checkpoints/epoch_' + str(epoch) + ".pth")

        # ---- validation phase ----
        model.eval()
        val_loss = 0
        val_mae = 0
        # No autograd graph is needed for evaluation; this saves memory.
        with torch.no_grad():
            for i, (img, gt_dmap) in enumerate(val_loader):
                img = img.to(args.device)
                gt_dmap = gt_dmap.to(args.device)
                # forward propagation
                et_dmap = model(img)
                loss = criterion(et_dmap, gt_dmap)
                val_loss += loss.item()
                val_mae += abs(et_dmap.sum() - gt_dmap.sum()).item()
                del img, gt_dmap, et_dmap

        epoch_val_mae = val_mae / len(val_loader)
        if epoch_val_mae < min_mae:
            min_mae = epoch_val_mae
            min_epoch = epoch

        wandb.log({"loss/train": train_loss / len(train_loader),
                   "mae/train": train_mae / len(train_loader),
                   "loss/val": val_loss / len(val_loader),
                   "mae/val": epoch_val_mae,
                   }, commit=False)

        # ---- log one random validation sample ----
        # Index over the dataset, not the loader: len(val_loader) counts
        # batches and would restrict the choice to the first few samples.
        index = random.randrange(len(val_dataset))
        img, gt_dmap = val_dataset[index]
        gt_dmap = gt_dmap.squeeze(0).detach().cpu().numpy()
        wandb.log({"image/img": [wandb.Image(img)]}, commit=False)
        wandb.log({"image/gt_dmap": [wandb.Image(
            _scale_for_display(gt_dmap), caption=str(gt_dmap.sum()))]},
            commit=False)

        img = img.unsqueeze(0).to(args.device)
        with torch.no_grad():
            et_dmap = model(img)
        et_dmap = et_dmap.squeeze(0).detach().cpu().numpy()
        # This final log call commits the wandb step for the epoch.
        wandb.log({"image/et_dmap": [wandb.Image(
            _scale_for_display(et_dmap), caption=str(et_dmap.sum()))]})

    print("min_mae:" + str(min_mae) + " min_epoch:" + str(min_epoch))
    print(time.strftime('%Y.%m.%d %H:%M:%S', time.localtime(time.time())))
# Body of one training epoch. NOTE(review): the enclosing epoch loop and the
# definitions of epoch, device, model, criterion, optimizer, the loaders and
# the *_list accumulators are not visible here -- confirm against the
# surrounding code.

# training phase: one optimizer step per batch
epoch_loss = 0
for i, (img, gt_dmap) in enumerate(tqdm(train_loader)):
    img = img.to(device)
    gt_dmap = gt_dmap.to(device)
    # forward propagation
    et_dmap = model(img)
    # calculate loss
    loss = criterion(et_dmap, gt_dmap)
    epoch_loss += loss.item()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
epoch_list.append(epoch)
train_loss_list.append(epoch_loss / len(train_loader))
torch.save(model.state_dict(), './checkpoints/epoch_' + str(epoch) + ".pth")

# testing phase
model.eval()
mae = 0
# Evaluation needs no autograd graph; disabling it saves memory and time.
with torch.no_grad():
    for i, (img, gt_dmap) in enumerate(tqdm(test_loader)):
        img = img.to(device)
        gt_dmap = gt_dmap.to(device)
        # forward propagation
        et_dmap = model(img)
        # absolute difference between the summed predicted and target maps
        mae += abs(et_dmap.sum() - gt_dmap.sum()).item()
        del img, gt_dmap, et_dmap

# Mean absolute error per test batch for this epoch.
epoch_mae = mae / len(test_loader)
if epoch_mae < min_mae:
    min_mae = epoch_mae
    min_epoch = epoch
test_error_list.append(epoch_mae)
# Body of one training epoch. NOTE(review): the enclosing epoch loop and the
# definitions of epoch, device, model, criterion, optimizer, the loaders and
# the *_list accumulators are not visible here -- confirm against the
# surrounding code.

# training phase: one optimizer step per batch
epoch_loss = 0
for i, (img, gt_dmap) in enumerate(tqdm(train_loader)):
    img = img.to(device)
    gt_dmap = gt_dmap.to(device)
    # forward propagation
    et_dmap = model(img)
    # calculate loss
    loss = criterion(et_dmap, gt_dmap)
    epoch_loss += loss.item()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
epoch_list.append(epoch)
train_loss_list.append(epoch_loss / len(train_loader))
torch.save(model.state_dict(), './checkpoints/epoch_' + str(epoch) + ".pth")

# testing phase
model.eval()
mae = 0
# Evaluation needs no autograd graph; disabling it saves memory and time.
with torch.no_grad():
    for i, (img, gt_dmap) in enumerate(tqdm(test_loader)):
        img = img.to(device)
        gt_dmap = gt_dmap.to(device)
        # forward propagation
        et_dmap = model(img)
        # absolute difference between the summed predicted and target maps
        mae += abs(et_dmap.sum() - gt_dmap.sum()).item()
        del img, gt_dmap, et_dmap

# Mean absolute error per test batch for this epoch.
epoch_mae = mae / len(test_loader)
if epoch_mae < min_mae:
    min_mae = epoch_mae
    min_epoch = epoch
epoch_list.append(epoch) train_loss_list.append(epoch_loss / len(train_loader)) # testing phase model.eval() mae = 0 for i, (img, gt_dmap) in enumerate(tqdm(test_loader)): img = img.to(device) gt_dmap = gt_dmap.to(device) # forward propagation et_dmap = model(img) mae += abs(et_dmap.data.sum() - gt_dmap.data.sum()).item() del img, gt_dmap, et_dmap if mae / len(test_loader) < min_mae: torch.save( model.state_dict(), './checkpoints/' + str(args.data_dir.split('/')[-1]) + '_epoch_' + str(epoch) + ".pth") min_mae = mae / len(test_loader) min_epoch = epoch test_error_list.append(mae / len(test_loader)) print("epoch:" + str(epoch) + " error:" + str(mae / len(test_loader)) + " min_mae:" + str(min_mae) + " min_epoch:" + str(min_epoch)) vis.line(win=1, X=epoch_list, Y=train_loss_list, opts=dict(title='train_loss')) vis.line(win=2, X=epoch_list, Y=test_error_list,