import os
import random
import time

import torch
import torch.nn as nn
import wandb
from tqdm import tqdm

# project-local modules; adjust the import paths to your repo layout
from cannet import CANNet
from my_dataset import CrowdDataset


def main(args):
    wandb.init(project="crowd", config=args)
    args = wandb.config

    torch.cuda.manual_seed(args.seed)
    model = CANNet().to(args.device)
    criterion = nn.MSELoss(reduction='sum').to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                 weight_decay=args.decay)

    train_dataset = CrowdDataset(args.train_image_root, args.train_dmap_root,
                                 gt_downsample=8, phase='train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    val_dataset = CrowdDataset(args.val_image_root, args.val_dmap_root,
                               gt_downsample=8, phase='test')
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False)

    if not os.path.exists('./checkpoints'):
        os.mkdir('./checkpoints')
    min_mae = 10000
    min_epoch = 0

    for epoch in tqdm(range(args.epochs)):
        # training phase
        model.train()
        model.zero_grad()
        train_loss = 0
        train_mae = 0
        train_bar = tqdm(train_loader)
        for i, (img, gt_dmap) in enumerate(train_bar):
            img = img.to(args.device)
            gt_dmap = gt_dmap.to(args.device)
            # forward propagation
            et_dmap = model(img)
            loss = criterion(et_dmap, gt_dmap)
            train_loss += loss.item()
            train_mae += abs(et_dmap.data.sum() - gt_dmap.data.sum()).item()
            # scale the loss so the accumulated gradient matches a larger effective batch
            loss = loss / args.gradient_accumulation_steps
            loss.backward()
            if (i + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                model.zero_grad()
            train_bar.set_postfix(loss=train_loss / (i + 1), mae=train_mae / (i + 1))
        # flush any gradients still accumulated at the end of the epoch
        optimizer.step()
        model.zero_grad()
        torch.save(model.state_dict(), './checkpoints/epoch_' + str(epoch) + ".pth")

        # validation phase
        model.eval()
        val_loss = 0
        val_mae = 0
        with torch.no_grad():
            for i, (img, gt_dmap) in enumerate(val_loader):
                img = img.to(args.device)
                gt_dmap = gt_dmap.to(args.device)
                # forward propagation
                et_dmap = model(img)
                loss = criterion(et_dmap, gt_dmap)
                val_loss += loss.item()
                val_mae += abs(et_dmap.data.sum() - gt_dmap.data.sum()).item()
                del img, gt_dmap, et_dmap
        if val_mae / len(val_loader) < min_mae:
            min_mae = val_mae / len(val_loader)
            min_epoch = epoch
        wandb.log({"loss/train": train_loss / len(train_loader),
                   "mae/train": train_mae / len(train_loader),
                   "loss/val": val_loss / len(val_loader),
                   "mae/val": val_mae / len(val_loader),
                   }, commit=False)

        # log a random validation sample with its ground-truth and estimated density maps
        index = random.randint(0, len(val_dataset) - 1)  # index the dataset, not the loader
        img, gt_dmap = val_dataset[index]
        gt_dmap = gt_dmap.squeeze(0).detach().cpu().numpy()
        wandb.log({"image/img": [wandb.Image(img)]}, commit=False)
        wandb.log({"image/gt_dmap": [wandb.Image(gt_dmap / gt_dmap.max() * 255,
                                                 caption=str(gt_dmap.sum()))]}, commit=False)
        img = img.unsqueeze(0).to(args.device)
        with torch.no_grad():
            et_dmap = model(img)
        et_dmap = et_dmap.squeeze(0).detach().cpu().numpy()
        wandb.log({"image/et_dmap": [wandb.Image(et_dmap / et_dmap.max() * 255,
                                                 caption=str(et_dmap.sum()))]})

    print(time.strftime('%Y.%m.%d %H:%M:%S', time.localtime(time.time())))
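# For completeness, a minimal argparse entry point that supplies the fields
# main() reads from args. The flag names mirror the attributes accessed above;
# the default values and paths are illustrative assumptions, not values from
# the original script, so adjust them to your setup.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # dataset paths (placeholders; point these at your own data)
    parser.add_argument("--train_image_root", default="./data/train/images")
    parser.add_argument("--train_dmap_root", default="./data/train/dmaps")
    parser.add_argument("--val_image_root", default="./data/val/images")
    parser.add_argument("--val_dmap_root", default="./data/val/dmaps")
    # optimization hyperparameters (illustrative defaults)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--decay", type=float, default=5e-4)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    main(parser.parse_args())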
# --- Alternative setup: SGD + visdom variant of the same training script.
# The dataset root variables and gpu_or_cpu are assumed to be defined earlier
# in that script; they are not shown in this excerpt. ---
lr = 1e-7
batch_size = 1
momentum = 0.95
epochs = 20000
steps = [-1, 1, 100, 150]
scales = [1, 1, 1, 1]
workers = 4
seed = time.time()
print_freq = 30

vis = visdom.Visdom()
device = torch.device(gpu_or_cpu)
torch.cuda.manual_seed(seed)
model = CANNet().to(device)
# size_average=False is deprecated; reduction='sum' is the equivalent spelling
criterion = nn.MSELoss(reduction='sum').to(device)
optimizer = torch.optim.SGD(model.parameters(), lr,
                            momentum=momentum,
                            weight_decay=0)

train_dataset = CrowdDataset(train_image_root, train_dmap_root, gt_downsample=8)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=1,
                                           shuffle=True)
test_dataset = CrowdDataset(test_image_root, test_dmap_root, gt_downsample=8)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=1,
                                          shuffle=False)  # not shuffled for evaluation (assumed)
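# The steps/scales pair above suggests a step-wise learning-rate schedule,
# although the helper that consumes them is not shown in this excerpt. A
# minimal sketch of one plausible implementation (the name
# adjust_learning_rate and its exact behavior are assumptions):
def adjust_learning_rate(optimizer, epoch):
    # Multiply the base lr by every scale whose step threshold has been reached.
    scale = 1.0
    for step, s in zip(steps, scales):
        if epoch >= step:
            scale *= s
        else:
            break
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * scale

# Note that with scales = [1, 1, 1, 1] as configured above the schedule is a
# no-op, so the learning rate stays at 1e-7 for the whole run.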