def main(args):
    """Build the train/val data loaders and the MSBR model.

    Args:
        args: parsed CLI namespace; uses `train_dir`, `val_dir`, `config`.

    Returns:
        (model, train_loader, val_loader) — `val_loader` is None when no
        validation directory is configured or `config.test_interval` is 0.
    """
    train_dir = args.train_dir
    val_dir = args.val_dir
    config = Config(args.config)
    cudnn.benchmark = True

    # Training data: LSPET with random augmentation to 368x368 crops.
    train_loader = torch.utils.data.DataLoader(
        lsp_lspet_data.LSP_Data(
            'lspet', train_dir, 8,
            Mytransforms.Compose([
                Mytransforms.RandomResized(),
                Mytransforms.RandomRotate(40),
                Mytransforms.RandomCrop(368),
                Mytransforms.RandomHorizontalFlip(),
            ])),
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.workers,
        pin_memory=True)

    # FIX: val_loader was only bound inside the guard below but is always
    # referenced in the return statement, raising UnboundLocalError when
    # validation is disabled. Default it to None.
    val_loader = None
    if args.val_dir is not None and config.test_interval != 0:
        # Validation data: LSP with a deterministic resize, no shuffling.
        val_loader = torch.utils.data.DataLoader(
            lsp_lspet_data.LSP_Data(
                'lsp', val_dir, 8,
                Mytransforms.Compose([
                    Mytransforms.TestResized(368),
                ])),
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=config.workers,
            pin_memory=True)

    # Build the multi-stage model for 14 keypoints.
    model = MSBR(config=config, args=args, k=14, stages=config.stages)
    model.build_nets()
    return model, train_loader, val_loader
def train_val(model, args):
    """Train `model` on LSPET with periodic validation on LSP.

    Saves a checkpoint every `config.display` training iterations, tracks
    the best validation loss, and on completion dumps the per-stage
    validation loss histories to loss1..loss6 .npy files.

    Args:
        model: the 6-stage pose network; called as model(input, centermap).
        args: parsed CLI namespace; uses `train_dir`, `val_dir`, `config`,
            `gpu` (list; gpu[0] < 0 means CPU-only) and `model_name`.
    """
    train_dir = args.train_dir
    val_dir = args.val_dir
    config = Config(args.config)
    cudnn.benchmark = True

    # Training data: LSPET with random augmentation to 368x368 crops.
    train_loader = torch.utils.data.DataLoader(
        lsp_lspet_data.LSP_Data(
            'lspet', train_dir, 8,
            Mytransforms.Compose([
                Mytransforms.RandomResized(),
                Mytransforms.RandomRotate(40),
                Mytransforms.RandomCrop(368),
                Mytransforms.RandomHorizontalFlip(),
            ])),
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.workers,
        pin_memory=True)

    if args.val_dir is not None and config.test_interval != 0:
        # Validation data: deterministic resize only.
        # FIX: shuffle was True; shuffling serves no purpose during
        # evaluation and makes per-display logs non-reproducible.
        val_loader = torch.utils.data.DataLoader(
            lsp_lspet_data.LSP_Data(
                'lsp', val_dir, 8,
                Mytransforms.Compose([
                    Mytransforms.TestResized(368),
                ])),
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=config.workers,
            pin_memory=True)

    # gpu[0] < 0 is the CPU-only convention used throughout this function.
    if args.gpu[0] < 0:
        criterion = nn.MSELoss()
    else:
        criterion = nn.MSELoss().cuda()

    params, multiple = get_parameters(model, config, True)
    # params, multiple = get_parameters(model, config, False)

    optimizer = torch.optim.SGD(params, config.base_lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_list = [AverageMeter() for i in range(6)]
    end = time.time()

    iters = config.start_iters
    best_model = config.best_model

    # Scale the MSE by heatmap area * channel count (46*46*15) so per-pixel
    # losses do not vanish in the average.
    heat_weight = 46 * 46 * 15 / 1.0

    losstracker1 = []
    losstracker2 = []
    losstracker3 = []
    losstracker4 = []
    losstracker5 = []
    losstracker6 = []

    while iters < config.max_iter:
        for i, (input, heatmap, centermap) in enumerate(train_loader):

            learning_rate = adjust_learning_rate(
                optimizer, iters, config.base_lr,
                policy=config.lr_policy,
                policy_parameter=config.policy_parameter,
                multiple=multiple)
            data_time.update(time.time() - end)

            if args.gpu[0] >= 0:
                # FIX: `async` is a reserved keyword in Python >= 3.7
                # (SyntaxError); the PyTorch keyword is `non_blocking`.
                heatmap = heatmap.cuda(non_blocking=True)
                centermap = centermap.cuda(non_blocking=True)

            input_var = torch.autograd.Variable(input)
            heatmap_var = torch.autograd.Variable(heatmap)
            centermap_var = torch.autograd.Variable(centermap)

            heat1, heat2, heat3, heat4, heat5, heat6 = model(
                input_var, centermap_var)

            # One supervised loss per refinement stage.
            loss1 = criterion(heat1, heatmap_var) * heat_weight
            loss2 = criterion(heat2, heatmap_var) * heat_weight
            loss3 = criterion(heat3, heatmap_var) * heat_weight
            loss4 = criterion(heat4, heatmap_var) * heat_weight
            loss5 = criterion(heat5, heatmap_var) * heat_weight
            loss6 = criterion(heat6, heatmap_var) * heat_weight

            loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6

            losses.update(loss.item(), input.size(0))
            for cnt, l in enumerate([loss1, loss2, loss3, loss4, loss5, loss6]):
                losses_list[cnt].update(l.item(), input.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            iters += 1
            if iters % config.display == 0:
                # FIX: format spec was `{data_time.avg:3f}` (width 3, full
                # precision); `.3f` matches the surrounding fields.
                print(
                    'Train Iteration: {0}\t'
                    'Time {batch_time.sum:.3f}s / {1}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {1}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {2}\n'
                    'Loss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                        iters, config.display, learning_rate,
                        batch_time=batch_time, data_time=data_time,
                        loss=losses))
                for cnt in range(0, 6):
                    print(
                        'Loss{0} = {loss1.val:.8f} (ave = {loss1.avg:.8f})\t'.
                        format(cnt + 1, loss1=losses_list[cnt]))
                print(
                    time.strftime(
                        '%Y-%m-%d %H:%M:%S -----------------------------------------------------------------------------------------------------------------\n',
                        time.localtime()))

                batch_time.reset()
                data_time.reset()
                losses.reset()
                for cnt in range(6):
                    losses_list[cnt].reset()

                save_checkpoint({
                    'iter': iters,
                    'state_dict': model.state_dict(),
                }, 0, args.model_name)

            # Periodic validation pass.
            if args.val_dir is not None and config.test_interval != 0 and iters % config.test_interval == 0:
                model.eval()
                # no_grad: validation needs no autograd graph (saves memory).
                with torch.no_grad():
                    for j, (input, heatmap, centermap) in enumerate(val_loader):
                        if args.gpu[0] >= 0:
                            # FIX: was `args.cuda[0]` — the rest of the
                            # function consistently uses `args.gpu[0]`.
                            heatmap = heatmap.cuda(non_blocking=True)
                            centermap = centermap.cuda(non_blocking=True)

                        input_var = torch.autograd.Variable(input)
                        heatmap_var = torch.autograd.Variable(heatmap)
                        centermap_var = torch.autograd.Variable(centermap)

                        heat1, heat2, heat3, heat4, heat5, heat6 = model(
                            input_var, centermap_var)

                        loss1 = criterion(heat1, heatmap_var) * heat_weight
                        loss2 = criterion(heat2, heatmap_var) * heat_weight
                        loss3 = criterion(heat3, heatmap_var) * heat_weight
                        loss4 = criterion(heat4, heatmap_var) * heat_weight
                        loss5 = criterion(heat5, heatmap_var) * heat_weight
                        loss6 = criterion(heat6, heatmap_var) * heat_weight

                        loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6

                        # FIX: `loss.data[0]` is the removed pre-0.4 idiom
                        # (errors on 0-dim tensors); the train loop already
                        # uses `.item()`.
                        losses.update(loss.item(), input.size(0))
                        for cnt, l in enumerate(
                                [loss1, loss2, loss3, loss4, loss5, loss6]):
                            losses_list[cnt].update(l.item(), input.size(0))

                        batch_time.update(time.time() - end)
                        end = time.time()

                        # NOTE(review): best-model bookkeeping and checkpoint
                        # saving run on every validation batch (original
                        # placement kept); consider moving them after the loop
                        # so `losses.avg` covers the whole validation set.
                        is_best = losses.avg < best_model
                        best_model = min(best_model, losses.avg)
                        save_checkpoint(
                            {
                                'iter': iters,
                                'state_dict': model.state_dict(),
                            }, is_best, args.model_name)

                        if j % config.display == 0:
                            print(
                                'Test Iteration: {0}\t'
                                'Time {batch_time.sum:.3f}s / {1}iters, ({batch_time.avg:.3f})\t'
                                'Data load {data_time.sum:.3f}s / {1}iters, ({data_time.avg:.3f})\n'
                                'Loss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.
                                format(j, config.display,
                                       batch_time=batch_time,
                                       data_time=data_time, loss=losses))
                            for cnt in range(0, 6):
                                print(
                                    'Loss{0} = {loss1.val:.8f} (ave = {loss1.avg:.8f})\t'
                                    .format(cnt + 1, loss1=losses_list[cnt]))
                            print(
                                time.strftime(
                                    '%Y-%m-%d %H:%M:%S -----------------------------------------------------------------------------------------------------------------\n',
                                    time.localtime()))
                            batch_time.reset()
                            losses.reset()
                            for cnt in range(6):
                                losses_list[cnt].reset()

                    # FIX: the trackers previously stored live loss tensors
                    # (possibly on GPU), so np.asarray/np.save below failed.
                    # Store plain floats (last validation batch per interval).
                    losstracker1.append(loss1.item())
                    losstracker2.append(loss2.item())
                    losstracker3.append(loss3.item())
                    losstracker4.append(loss4.item())
                    losstracker5.append(loss5.item())
                    losstracker6.append(loss6.item())

                model.train()

    # Persist the per-stage validation loss histories.
    np.save('loss1', np.asarray(losstracker1))
    np.save('loss2', np.asarray(losstracker2))
    np.save('loss3', np.asarray(losstracker3))
    np.save('loss4', np.asarray(losstracker4))
    np.save('loss5', np.asarray(losstracker5))
    np.save('loss6', np.asarray(losstracker6))
def train_val(model, args):
    """Train `model` on LSPET with periodic validation on LSP, dumping
    rendered keypoint overlays for each displayed training batch.

    NOTE(review): this is the second `train_val` definition in this file and
    shadows the earlier one — only this version is reachable. Keep one, or
    rename this variant.

    Args:
        model: the 6-stage pose network; called as model(input, centermap).
        args: parsed CLI namespace; uses `train_dir`, `val_dir`, `config`
            and `model_name`. This variant assumes CUDA is available.
    """
    train_dir = args.train_dir
    val_dir = args.val_dir
    config = Config(args.config)
    cudnn.benchmark = True

    # The LSPET dataset contains 10000 images, LSP contains 2000 images.

    # Training data: LSPET with random augmentation to 368x368 crops.
    # This variant's dataset also yields the source image path per sample.
    train_loader = torch.utils.data.DataLoader(
        lsp_lspet_data.LSP_Data(
            'lspet', train_dir, 8,
            Mytransforms.Compose([
                Mytransforms.RandomResized(),
                Mytransforms.RandomRotate(40),
                Mytransforms.RandomCrop(368),
                Mytransforms.RandomHorizontalFlip(),
            ])),
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.workers,
        pin_memory=True)

    if args.val_dir is not None and config.test_interval != 0:
        # Validation data: deterministic resize only.
        # FIX: shuffle was True; shuffling serves no purpose for evaluation.
        val_loader = torch.utils.data.DataLoader(
            lsp_lspet_data.LSP_Data(
                'lsp', val_dir, 8,
                Mytransforms.Compose([
                    Mytransforms.TestResized(368),
                ])),
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=config.workers,
            pin_memory=True)

    criterion = nn.MSELoss().cuda()

    params, multiple = get_parameters(model, config, False)

    optimizer = torch.optim.SGD(params, config.base_lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_list = [AverageMeter() for i in range(6)]
    end = time.time()

    iters = config.start_iters
    best_model = config.best_model

    # Scale the MSE by heatmap area * channel count (46*46*15) so per-pixel
    # losses do not vanish in the average.
    heat_weight = 46 * 46 * 15 / 1.0

    while iters < config.max_iter:
        # Each iteration of train_loader advances i by 1 and yields one
        # batch of (input, heatmap, centermap, img_path).
        for i, (input, heatmap, centermap, img_path) in enumerate(train_loader):

            learning_rate = adjust_learning_rate(
                optimizer, iters, config.base_lr,
                policy=config.lr_policy,
                policy_parameter=config.policy_parameter,
                multiple=multiple)
            data_time.update(time.time() - end)

            # FIX: `async` is a reserved keyword in Python >= 3.7
            # (SyntaxError); the PyTorch keyword is `non_blocking`.
            heatmap = heatmap.cuda(non_blocking=True)
            centermap = centermap.cuda(non_blocking=True)

            input_var = torch.autograd.Variable(input)
            heatmap_var = torch.autograd.Variable(heatmap)
            centermap_var = torch.autograd.Variable(centermap)

            heat1, heat2, heat3, heat4, heat5, heat6 = model(
                input_var, centermap_var)

            # One supervised loss per refinement stage.
            loss1 = criterion(heat1, heatmap_var) * heat_weight
            loss2 = criterion(heat2, heatmap_var) * heat_weight
            loss3 = criterion(heat3, heatmap_var) * heat_weight
            loss4 = criterion(heat4, heatmap_var) * heat_weight
            loss5 = criterion(heat5, heatmap_var) * heat_weight
            loss6 = criterion(heat6, heatmap_var) * heat_weight

            loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6

            # FIX: `loss.data[0]` is the removed pre-0.4 idiom (errors on
            # 0-dim tensors); use `.item()`.
            losses.update(loss.item(), input.size(0))
            for cnt, l in enumerate([loss1, loss2, loss3, loss4, loss5, loss6]):
                losses_list[cnt].update(l.item(), input.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            iters += 1
            if iters % config.display == 0:
                # FIX: format spec was `{data_time.avg:3f}` (width 3, full
                # precision); `.3f` matches the surrounding fields.
                print(
                    'Train Iteration: {0}\t'
                    'Time {batch_time.sum:.3f}s / {1}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {1}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {2}\n'
                    'Loss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                        iters, config.display, learning_rate,
                        batch_time=batch_time, data_time=data_time,
                        loss=losses))
                for cnt in range(0, 6):
                    print(
                        'Loss{0} = {loss1.val:.8f} (ave = {loss1.avg:.8f})\t'.
                        format(cnt + 1, loss1=losses_list[cnt]))
                print(
                    time.strftime(
                        '%Y-%m-%d %H:%M:%S -----------------------------------------------------------------------------------------------------------------\n',
                        time.localtime()))

                # ----- image write: overlay predicted keypoints ------------
                # FIX: iterated range(config.batch_size), which raises
                # IndexError on a short final batch; use the actual batch
                # size of the current predictions.
                for cnt in range(heat6.size(0)):
                    kpts = get_kpts(heat6[cnt], img_h=368.0, img_w=368.0)
                    draw_paint(img_path[cnt], kpts, i, cnt)
                # -----------------------------------------------------------

                batch_time.reset()
                data_time.reset()
                losses.reset()
                for cnt in range(6):
                    losses_list[cnt].reset()

                save_checkpoint({
                    'iter': iters,
                    'state_dict': model.state_dict(),
                }, 0, args.model_name)

            # Periodic validation pass.
            if args.val_dir is not None and config.test_interval != 0 and iters % config.test_interval == 0:
                model.eval()
                # no_grad: validation needs no autograd graph (saves memory).
                with torch.no_grad():
                    for j, (input, heatmap, centermap) in enumerate(val_loader):
                        heatmap = heatmap.cuda(non_blocking=True)
                        centermap = centermap.cuda(non_blocking=True)

                        input_var = torch.autograd.Variable(input)
                        heatmap_var = torch.autograd.Variable(heatmap)
                        centermap_var = torch.autograd.Variable(centermap)

                        heat1, heat2, heat3, heat4, heat5, heat6 = model(
                            input_var, centermap_var)

                        loss1 = criterion(heat1, heatmap_var) * heat_weight
                        loss2 = criterion(heat2, heatmap_var) * heat_weight
                        loss3 = criterion(heat3, heatmap_var) * heat_weight
                        loss4 = criterion(heat4, heatmap_var) * heat_weight
                        loss5 = criterion(heat5, heatmap_var) * heat_weight
                        loss6 = criterion(heat6, heatmap_var) * heat_weight

                        loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6

                        losses.update(loss.item(), input.size(0))
                        for cnt, l in enumerate(
                                [loss1, loss2, loss3, loss4, loss5, loss6]):
                            losses_list[cnt].update(l.item(), input.size(0))

                        batch_time.update(time.time() - end)
                        end = time.time()

                        # NOTE(review): best-model bookkeeping and checkpoint
                        # saving run on every validation batch (original
                        # placement kept); consider moving them after the
                        # loop so `losses.avg` covers the whole set.
                        is_best = losses.avg < best_model
                        best_model = min(best_model, losses.avg)
                        save_checkpoint(
                            {
                                'iter': iters,
                                'state_dict': model.state_dict(),
                            }, is_best, args.model_name)

                        if j % config.display == 0:
                            print(
                                'Test Iteration: {0}\t'
                                'Time {batch_time.sum:.3f}s / {1}iters, ({batch_time.avg:.3f})\t'
                                'Data load {data_time.sum:.3f}s / {1}iters, ({data_time.avg:.3f})\n'
                                'Loss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.
                                format(j, config.display,
                                       batch_time=batch_time,
                                       data_time=data_time, loss=losses))
                            for cnt in range(0, 6):
                                print(
                                    'Loss{0} = {loss1.val:.8f} (ave = {loss1.avg:.8f})\t'
                                    .format(cnt + 1, loss1=losses_list[cnt]))
                            print(
                                time.strftime(
                                    '%Y-%m-%d %H:%M:%S -----------------------------------------------------------------------------------------------------------------\n',
                                    time.localtime()))
                            batch_time.reset()
                            losses.reset()
                            for cnt in range(6):
                                losses_list[cnt].reset()

                model.train()