def main(args):
    """Train / evaluate the single-head hourglass affordance model (IoU loss).

    Side effects: mutates module globals (best_iou, idx, output_res, RELABEL,
    JUST_EVALUATE), creates checkpoint directories, writes a log file and
    model checkpoints, and copies source files into a backup directory.
    """
    global best_iou
    global idx
    global output_res
    output_res = args.out_res  # 2020.3.2
    global REDRAW  # 2020.3.4

    # If args.resume is given, args.checkpoint is derived from it below.
    if args.pre_train:
        # Pre-training uses a reduced learning rate.
        # (Removed dead local `lr = 5e-4` — it was never read.)
        args.lr = 5e-5
    if args.resume != '' and args.pre_train == False:
        args.checkpoint = ('/').join(args.resume.split('/')[:2])
    if args.relabel == True:
        args.test_batch = 1
    if args.test == True:
        # Quick smoke-test configuration.
        args.train_batch = 4
        args.test_batch = 4
        args.epochs = 20
    if args.evaluate and args.relabel == False:
        args.test_batch = 1

    # Write line-chart from the existing log and stop the program.
    if args.write:
        draw_line_chart(args, os.path.join(args.checkpoint, 'log.txt'))
        return

    # idx: index of the joints used to compute accuracy (single affordance).
    idx = [1]

    # Create checkpoint dir.
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Create model.
    njoints = datasets.__dict__[args.dataset].njoints
    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)
    model = torch.nn.DataParallel(model).to(device)

    # Define loss function (criterion) and optimizer.
    criterion = losses.IoULoss().to(device)
    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
        )
    else:
        print('Unknown solver: {}'.format(args.solver))
        assert False

    # Optionally resume from a checkpoint.
    title = args.dataset + ' ' + args.arch
    # FIX: initialize unconditionally; the fresh-run branch below previously
    # left the global unset, causing a NameError in the epoch loop.
    best_iou = 0
    if args.pre_train:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # Start from epoch 0 (only weights are reused).
            args.start_epoch = 0
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
            logger.set_names(
                ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Val IoU'])
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    elif args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # Deliberately restart epoch / best-iou counters instead of
            # restoring them from the checkpoint.
            args.start_epoch = 0
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'),
                            title=title,
                            resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Val IoU'])

    print(' Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Create data loaders (dataset class selected by args.dataset).
    train_dataset = datasets.__dict__[args.dataset](is_train=True,
                                                    **vars(args))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_dataset = datasets.__dict__[args.dataset](is_train=False, **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # Redraw training / test label.
    global RELABEL
    if args.relabel:
        RELABEL = True
        if args.evaluate:
            print('\nRelabel val label')
            new_checkpoint = 'checkpoint_0701_bbox_hide'
            mkdir_p(new_checkpoint)
            loss, iou, predictions = validate(val_loader, model, criterion,
                                              njoints, new_checkpoint,
                                              args.debug, args.flip)
            # Because test and val are all considered, iou is useless here.
            return

    # Evaluation only.
    global JUST_EVALUATE
    JUST_EVALUATE = False
    if args.evaluate:
        print('\nEvaluation only')
        if args.debug:
            print('Draw pred /gt heatmap')
        JUST_EVALUATE = True
        new_checkpoint = 'checkpoint_0701_bbox_hide'
        mkdir_p(new_checkpoint)
        loss, iou, predictions = validate(val_loader, model, criterion,
                                          njoints, new_checkpoint,
                                          args.debug, args.flip)
        print("Val IoU: %.3f" % (iou))
        return

    # Back up key source files when training starts.
    code_backup_dir = 'code_backup'
    mkdir_p(os.path.join(args.checkpoint, code_backup_dir))
    os.system('cp ../affordance/models/hourglass.py %s/%s/hourglass.py' %
              (args.checkpoint, code_backup_dir))
    os.system('cp ../affordance/datasets/sad.py %s/%s/sad.py' %
              (args.checkpoint, code_backup_dir))
    this_file_name = os.path.split(os.path.abspath(__file__))[1]
    os.system('cp ./%s %s' %
              (this_file_name,
               os.path.join(args.checkpoint, code_backup_dir,
                            this_file_name)))

    # Train and eval.
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # Decay sigma (heatmap Gaussian width).
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        # Train for one epoch.
        train_loss = train(train_loader, model, criterion, optimizer,
                           args.debug, args.flip)

        # Evaluate on validation set.
        # BUG FIX: was `arg.checkpoint` (undefined name `arg` -> NameError).
        valid_loss, valid_iou, predictions = validate(val_loader, model,
                                                      criterion, njoints,
                                                      args.checkpoint,
                                                      args.debug, args.flip)
        print("Val IoU: %.3f" % (valid_iou))

        # Append logger file.
        logger.append([epoch + 1, lr, train_loss, valid_loss, valid_iou])

        # Remember best IoU and save checkpoint.
        is_best_iou = valid_iou > best_iou
        best_iou = max(valid_iou, best_iou)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_iou': best_iou,
                'optimizer': optimizer.state_dict(),
            },
            is_best_iou,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()
    print("Best iou = %.3f" % (best_iou))
def main(args):
    """Train / evaluate the multi-head hourglass model (attention heatmap,
    affordance region, and existence-classification heads).

    Side effects: mutates module globals (best_final_acc, idx, output_res,
    RELABEL, JUST_EVALUATE), creates checkpoint directories, writes a log
    file and model checkpoints, and copies source files into a backup dir.
    """
    global best_final_acc
    global idx
    global output_res
    output_res = args.out_res  # 2020.3.2
    global REDRAW  # 2020.3.4

    # If args.resume is given, args.checkpoint is derived from it below.
    if args.pre_train:
        # Pre-training uses a reduced learning rate.
        # (Removed dead local `lr = 5e-4` — it was never read.)
        args.lr = 5e-5
    if args.resume != '' and args.pre_train == False:
        args.checkpoint = ('/').join(args.resume.split('/')[:2])
    if args.relabel == True:
        args.test_batch = 1
    if args.test == True:
        # Quick smoke-test configuration.
        args.train_batch = 4
        args.test_batch = 4
        args.epochs = 10
    if args.evaluate and args.relabel == False:
        args.test_batch = 4

    # Write line-chart from the existing log and stop the program.
    if args.write:
        draw_line_chart(args, os.path.join(args.checkpoint, 'log.txt'))
        return

    # idx: index of the joints used to compute accuracy (single affordance).
    idx = [1]

    # Create checkpoint dir.
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Create model.
    njoints = datasets.__dict__[args.dataset].njoints
    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)
    model = torch.nn.DataParallel(model).to(device)

    # Define loss functions (one per head) and optimizer.
    criterion_iou = losses.IoULoss().to(device)
    criterion_bce = losses.BCELoss().to(device)
    criterion_focal = losses.FocalLoss().to(device)
    criterions = [criterion_iou, criterion_bce, criterion_focal]
    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
        )
    else:
        print('Unknown solver: {}'.format(args.solver))
        assert False

    # Optionally resume from a checkpoint.
    title = args.dataset + ' ' + args.arch
    # FIX: initialize unconditionally; the fresh-run branch below previously
    # left the global unset, causing a NameError in the epoch loop.
    best_final_acc = 0
    if args.pre_train:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # Start from epoch 0 (only weights are reused).
            args.start_epoch = 0
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
            # FIX: third column was duplicated as 'Val Attention Loss';
            # it must be 'Val Attention IoU' to match the fresh-run header
            # and the values appended in the training loop.
            logger.set_names(['Epoch', 'LR', 'Train Attention Loss',
                              'Val Attention Loss', 'Val Attention IoU',
                              'Train Region Loss', 'Val Region Loss',
                              'Val Region IoU',
                              'Train Existence Acc', 'Val Existence Loss',
                              'Val Existence Acc', 'Val final acc'])
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    elif args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # Deliberately restart epoch / best-acc counters instead of
            # restoring them from the checkpoint.
            args.start_epoch = 0
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'),
                            title=title,
                            resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Epoch', 'LR', 'Train Attention Loss',
                          'Val Attention Loss', 'Val Attention IoU',
                          'Train Region Loss', 'Val Region Loss',
                          'Val Region IoU',
                          'Train Existence Acc', 'Val Existence Loss',
                          'Val Existence Acc', 'Val final acc'])

    print(' Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Create data loaders (dataset class selected by args.dataset).
    train_dataset = datasets.__dict__[args.dataset](is_train=True,
                                                    **vars(args))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_dataset = datasets.__dict__[args.dataset](is_train=False, **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # Redraw training / test label.
    global RELABEL
    if args.relabel:
        RELABEL = True
        if args.evaluate:
            print('\nRelabel val label')
            val_att_loss, val_att_iou, val_region_loss, val_region_iou, \
                val_existence_loss, val_existence_acc, val_final_acc \
                = validate(val_loader, model, criterions, njoints,
                           args.checkpoint, args.debug, args.flip)
            print("Val final acc: %.3f" % (val_final_acc))
            # Because test and val are all considered, iou is useless here.
            return

    # Evaluation only.
    global JUST_EVALUATE
    JUST_EVALUATE = False
    if args.evaluate:
        print('\nEvaluation only')
        JUST_EVALUATE = True
        val_att_loss, val_att_iou, val_region_loss, val_region_iou, \
            val_existence_loss, val_existence_acc, val_final_acc \
            = validate(val_loader, model, criterions, njoints,
                       args.checkpoint, args.debug, args.flip)
        print("Val final acc: %.3f" % (val_final_acc))
        return

    # Back up key source files when training starts.
    code_backup_dir = 'code_backup'
    mkdir_p(os.path.join(args.checkpoint, code_backup_dir))
    os.system(
        'cp ../affordance/models/hourglass_final.py %s/%s/hourglass_final.py'
        % (args.checkpoint, code_backup_dir))
    os.system(
        'cp ../affordance/datasets/sad_attention.py %s/%s/sad_attention.py'
        % (args.checkpoint, code_backup_dir))
    this_file_name = os.path.split(os.path.abspath(__file__))[1]
    os.system('cp ./%s %s' %
              (this_file_name,
               os.path.join(args.checkpoint, code_backup_dir,
                            this_file_name)))

    # Train and eval.
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # Decay sigma (heatmap Gaussian width).
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        # Train for one epoch.
        train_att_loss, train_region_loss, train_existence_loss \
            = train(train_loader, model, criterions, optimizer, args.debug,
                    args.flip)

        # Evaluate on validation set.
        val_att_loss, val_att_iou, val_region_loss, val_region_iou, \
            val_existence_loss, val_existence_acc, val_final_acc \
            = validate(val_loader, model, criterions, njoints,
                       args.checkpoint, args.debug, args.flip)
        print("Val region IoU: %.3f" % (val_region_iou))
        print("Val label acc: %.3f" % (val_existence_acc))
        # NOTE(review): this overwrites the val_final_acc returned by
        # validate() with region-IoU + existence-acc — presumably the
        # intended model-selection score; confirm against validate().
        val_final_acc = val_region_iou + val_existence_acc

        # Append logger file.
        logger.append([epoch + 1, lr, train_att_loss, val_att_loss,
                       val_att_iou, train_region_loss, val_region_loss,
                       val_region_iou, train_existence_loss,
                       val_existence_loss, val_existence_acc, val_final_acc])

        # Remember best acc and save checkpoint.
        is_best_acc = val_final_acc > best_final_acc
        best_final_acc = max(val_final_acc, best_final_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_iou': best_final_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best_acc,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()
    print("Best val final acc = %.3f" % (best_final_acc))