def main(args):
    """Train or evaluate the attention / region / existence model chosen by ``args``.

    Depending on the flags carried by ``args`` this function:

    * draws the line chart for ``<checkpoint>/log.txt`` and returns
      (``args.write``);
    * runs one validation pass and returns (``args.evaluate``, optionally
      combined with ``args.relabel``);
    * otherwise backs up a few source files into the checkpoint directory and
      runs the train/validate loop from ``args.start_epoch`` up to
      ``args.epochs``, logging every epoch and saving a checkpoint.

    ``args`` is mutated in place (``lr``, ``checkpoint``, ``train_batch``,
    ``test_batch``, ``epochs`` and ``start_epoch`` may be overwritten), and the
    module globals ``output_res``, ``idx``, ``best_final_acc``, ``RELABEL`` and
    ``JUST_EVALUATE`` are written.

    Args:
        args: parsed command-line namespace.

    Raises:
        ValueError: if ``args.solver`` is neither ``'rms'`` nor ``'adam'``.
        FileNotFoundError: if training is requested with ``args.pre_train`` or
            ``args.resume`` set but ``args.resume`` is not an existing file.
    """
    global best_final_acc, idx, output_res
    global REDRAW, RELABEL, JUST_EVALUATE

    output_res = args.out_res  # 2020.3.2

    if args.pre_train:
        # pre train lr = 5e-4
        args.lr = 5e-5
    # if you do type arg.resume
    # args.checkpoint would be derived from arg.resume
    if args.resume != '' and args.pre_train == False:
        # NOTE(review): keeps only the first two '/'-separated components of
        # the resume path - confirm this matches the checkpoint layout in use.
        args.checkpoint = ('/').join(args.resume.split('/')[:2])
    if args.relabel == True:
        args.test_batch = 1
    if args.test == True:
        args.train_batch = 4
        args.test_batch = 4
        args.epochs = 10
    if args.evaluate and args.relabel == False:
        args.test_batch = 4

    log_path = os.path.join(args.checkpoint, 'log.txt')

    # write line-chart and stop program
    if args.write:
        draw_line_chart(args, log_path)
        return

    idx = [1]

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    njoints = datasets.__dict__[args.dataset].njoints
    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)
    model = torch.nn.DataParallel(model).to(device)

    # define loss functions (criterions) and optimizer
    criterions = [
        losses.IoULoss().to(device),
        losses.BCELoss().to(device),
        losses.FocalLoss().to(device),
    ]

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        # An explicit raise survives ``python -O``; a bare assert would not.
        raise ValueError('Unknown solver: {}'.format(args.solver))

    # Column names of the log file. The third column used to be a second
    # 'Val Attention Loss' in the pre-train header although the value logged
    # there is the validation attention IoU.
    # NOTE(review): the column named 'Train Existence Acc' receives
    # train_existence_loss in the loop below; the name is kept unchanged
    # because other code may look columns up by name.
    log_columns = ['Epoch', 'LR',
                   'Train Attention Loss', 'Val Attention Loss', 'Val Attention IoU',
                   'Train Region Loss', 'Val Region Loss', 'Val Region IoU',
                   'Train Existence Acc', 'Val Existence Loss', 'Val Existence Acc',
                   'Val final acc']

    # optionally resume from a checkpoint
    title = args.dataset + ' ' + args.arch
    logger = None  # stays None when the requested checkpoint is missing
    if args.pre_train or args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # start from epoch 0: only the weights are restored, not the
            # epoch counter, best score or optimizer state
            args.start_epoch = 0
            best_final_acc = 0
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            if args.pre_train:
                # fine-tuning from a pre-trained model starts a new log
                logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
                logger.set_names(log_columns)
            else:
                logger = Logger(join(args.checkpoint, 'log.txt'),
                                title=title, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(log_columns)

    print(' Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # create data loaders; args.dataset selects the dataset class
    train_dataset = datasets.__dict__[args.dataset](is_train=True, **vars(args))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    # Debug aid: iterate train_loader once and print the shapes of
    # (input, input_depth, target_heatmap, target_mask, target_label, meta).

    val_dataset = datasets.__dict__[args.dataset](is_train=False, **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # redraw training / test label
    if args.relabel:
        RELABEL = True
        if args.evaluate:
            print('\nRelabel val label')
            (val_att_loss, val_att_iou, val_region_loss, val_region_iou,
             val_existence_loss, val_existence_acc, val_final_acc) = validate(
                 val_loader, model, criterions, njoints,
                 args.checkpoint, args.debug, args.flip)
            print("Val final acc: %.3f" % (val_final_acc))
            # Because test and val are all considered -> iou is useless
            return

    # evaluation only
    JUST_EVALUATE = False
    if args.evaluate:
        print('\nEvaluation only')
        JUST_EVALUATE = True
        (val_att_loss, val_att_iou, val_region_loss, val_region_iou,
         val_existence_loss, val_existence_acc, val_final_acc) = validate(
             val_loader, model, criterions, njoints,
             args.checkpoint, args.debug, args.flip)
        print("Val final acc: %.3f" % (val_final_acc))
        return

    # Fail before training instead of after a whole epoch: without a logger
    # the first logger.append() below would hit an unbound local.
    if logger is None:
        raise FileNotFoundError(
            "cannot train: no checkpoint found at '{}'".format(args.resume))

    ## backup when training starts
    code_backup_dir = 'code_backup'
    mkdir_p(os.path.join(args.checkpoint, code_backup_dir))
    os.system(
        'cp ../affordance/models/hourglass_final.py %s/%s/hourglass_final.py' %
        (args.checkpoint, code_backup_dir))
    os.system(
        'cp ../affordance/datasets/sad_attention.py %s/%s/sad_attention.py' %
        (args.checkpoint, code_backup_dir))
    this_file_name = os.path.split(os.path.abspath(__file__))[1]
    os.system('cp ./%s %s' %
              (this_file_name,
               os.path.join(args.checkpoint, code_backup_dir, this_file_name)))

    # train and eval
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # decay sigma
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        # train for one epoch
        train_att_loss, train_region_loss, train_existence_loss = train(
            train_loader, model, criterions, optimizer, args.debug, args.flip)

        # evaluate on validation set
        (val_att_loss, val_att_iou, val_region_loss, val_region_iou,
         val_existence_loss, val_existence_acc, val_final_acc) = validate(
             val_loader, model, criterions, njoints,
             args.checkpoint, args.debug, args.flip)
        print("Val region IoU: %.3f" % (val_region_iou))
        print("Val label acc: %.3f" % (val_existence_acc))
        # the score used for model selection replaces the one from validate()
        val_final_acc = val_region_iou + val_existence_acc

        # append logger file
        logger.append([epoch + 1, lr,
                       train_att_loss, val_att_loss, val_att_iou,
                       train_region_loss, val_region_loss, val_region_iou,
                       train_existence_loss, val_existence_loss, val_existence_acc,
                       val_final_acc])

        # remember best acc and save checkpoint
        is_best_acc = val_final_acc > best_final_acc
        best_final_acc = max(val_final_acc, best_final_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                # key name kept as 'best_iou' so existing checkpoints stay readable
                'best_iou': best_final_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best_acc,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()
    print("Best val final acc = %.3f" % (best_final_acc))
def main(args):
    """Train or evaluate the single-criterion (BCE) model chosen by ``args``.

    Depending on the flags carried by ``args`` this function:

    * draws the line chart for ``<checkpoint>/log.txt`` and returns
      (``args.write``);
    * runs one validation pass and returns (``args.evaluate``, optionally
      combined with ``args.relabel``);
    * otherwise backs up a few source files into the checkpoint directory and
      runs the train/validate loop from ``args.start_epoch`` up to
      ``args.epochs``, logging every epoch and saving a checkpoint.

    ``args`` is mutated in place (``checkpoint``, ``train_batch``,
    ``test_batch``, ``epochs`` and ``start_epoch`` may be overwritten), and the
    module globals ``output_res``, ``idx``, ``best_iou``, ``RELABEL`` and
    ``JUST_EVALUATE`` are written.

    Args:
        args: parsed command-line namespace.

    Raises:
        ValueError: if ``args.dataset`` or ``args.solver`` is not one of the
            values handled below.
        FileNotFoundError: if training is requested with ``args.resume`` set
            but ``args.resume`` is not an existing file.
    """
    global best_iou, idx, output_res
    global REDRAW, RELABEL, JUST_EVALUATE

    output_res = args.out_res  # 2020.3.2

    # if you do type arg.resume
    # args.checkpoint would be derived from arg.resume
    if args.resume != '':
        # NOTE(review): keeps only the first two '/'-separated components of
        # the resume path - confirm this matches the checkpoint layout in use.
        args.checkpoint = ('/').join(args.resume.split('/')[:2])
    if args.relabel == True:
        args.test_batch = 1
    elif args.test == True:
        # alternative quick-test setting tried before: batch 2 / 2, 10 epochs
        args.train_batch = 1
        args.test_batch = 1
        args.epochs = 20

    # write line-chart and stop program
    if args.write:
        draw_line_chart(args, os.path.join(args.checkpoint, 'log.txt'))
        return

    # idx is the index of joints used to compute accuracy
    if args.dataset in ('mpii', 'lsp'):
        idx = [1, 2, 3, 4, 5, 6, 11, 12, 15, 16]
    elif args.dataset == 'coco':
        idx = list(range(1, 18))
    elif args.dataset in ('sad', 'sad_step_2', 'sad_step_2_eval'):
        idx = [1]  # support affordance
    else:
        # An explicit raise survives ``python -O``; a bare assert would not.
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    njoints = datasets.__dict__[args.dataset].njoints
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)
    model = torch.nn.DataParallel(model).to(device)

    # define loss function (criterion) and optimizer
    criterion = losses.BCELoss().to(device)

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise ValueError('Unknown solver: {}'.format(args.solver))

    # optionally resume from a checkpoint
    title = args.dataset + ' ' + args.arch
    logger = None  # stays None when the requested checkpoint is missing
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # full resume: epoch counter, best score, weights and optimizer
            args.start_epoch = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'),
                            title=title, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Val Acc'])

    print(' Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # create data loaders; args.dataset selects the dataset class
    train_dataset = datasets.__dict__[args.dataset](is_train=True, **vars(args))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    # Debug aid: iterate train_loader once and print the shapes of
    # (input, input_depth, input_mask, target, meta).

    val_dataset = datasets.__dict__[args.dataset](is_train=False, **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # redraw training / test label
    if args.relabel:
        RELABEL = True
        if args.evaluate:
            print('\nRelabel val label')
            loss, acc = validate(val_loader, model, criterion, njoints,
                                 args.checkpoint, args.debug, args.flip)
            print("Val acc: %.3f" % (acc))
            return

    # evaluation only
    JUST_EVALUATE = False
    if args.evaluate:
        print('\nEvaluation only')
        JUST_EVALUATE = True
        loss, acc = validate(val_loader, model, criterion, njoints,
                             args.checkpoint, args.debug, args.flip)
        print("Val acc: %.3f" % (acc))
        return

    # Fail before training instead of after a whole epoch: without a logger
    # the first logger.append() below would hit an unbound local.
    if logger is None:
        raise FileNotFoundError(
            "cannot train: no checkpoint found at '{}'".format(args.resume))

    ## backup when training starts
    code_backup_dir = 'code_backup'
    mkdir_p(os.path.join(args.checkpoint, code_backup_dir))
    os.system(
        'cp ../affordance/models/affordance_classification.py %s/%s/affordance_classification.py'
        % (args.checkpoint, code_backup_dir))
    os.system('cp ../affordance/models/convlstm.py %s/%s/convlstm.py' %
              (args.checkpoint, code_backup_dir))
    os.system(
        'cp ../affordance/datasets/sad_step_2_eval.py %s/%s/sad_step_2_eval.py' %
        (args.checkpoint, code_backup_dir))
    this_file_name = os.path.split(os.path.abspath(__file__))[1]
    os.system('cp ./%s %s' %
              (this_file_name,
               os.path.join(args.checkpoint, code_backup_dir, this_file_name)))

    # train and eval
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # decay sigma
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer,
                           args.debug, args.flip)

        # evaluate on validation set
        valid_loss, valid_acc = validate(val_loader, model, criterion, njoints,
                                         args.checkpoint, args.debug, args.flip)
        print("Val acc: %.3f" % (valid_acc))

        # append logger file
        logger.append([epoch + 1, lr, train_loss, valid_loss, valid_acc])

        # remember best acc and save checkpoint
        is_best_iou = valid_acc > best_iou
        best_iou = max(valid_acc, best_iou)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_iou': best_iou,
                'optimizer': optimizer.state_dict(),
            },
            is_best_iou,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()
    print("Best acc = %.3f" % (best_iou))
def main(args):
    """Build the model and data loaders from ``args`` and run validation.

    Depending on the flags carried by ``args`` this function:

    * draws the line chart for ``<checkpoint>/log.txt`` and returns
      (``args.write``);
    * runs one validation pass and returns (``args.evaluate``, optionally
      combined with ``args.relabel``).

    ``args`` is mutated in place (``checkpoint``, ``train_batch``,
    ``test_batch``, ``epochs`` and ``start_epoch`` may be overwritten), and the
    module globals ``output_res``, ``idx``, ``best_iou``, ``RELABEL`` and
    ``JUST_EVALUATE`` are written.

    Args:
        args: parsed command-line namespace.

    Raises:
        ValueError: if ``args.dataset`` or ``args.solver`` is not one of the
            values handled below.
    """
    global best_iou, idx, output_res
    global REDRAW, RELABEL, JUST_EVALUATE

    output_res = args.out_res  # 2020.3.2

    # if you do type arg.resume
    # args.checkpoint would be derived from arg.resume
    if args.resume != '':
        # NOTE(review): keeps only the first two '/'-separated components of
        # the resume path - confirm this matches the checkpoint layout in use.
        args.checkpoint = ('/').join(args.resume.split('/')[:2])
    if args.relabel == True:
        args.test_batch = 1
    elif args.test == True:
        # alternative quick-test setting tried before: batch 4 / 4, 20 epochs
        args.train_batch = 2
        args.test_batch = 2
        args.epochs = 10

    # write line-chart and stop program
    if args.write:
        draw_line_chart(args, os.path.join(args.checkpoint, 'log.txt'))
        return

    # idx is the index of joints used to compute accuracy
    if args.dataset in ('mpii', 'lsp'):
        idx = [1, 2, 3, 4, 5, 6, 11, 12, 15, 16]
    elif args.dataset == 'coco':
        idx = list(range(1, 18))
    elif args.dataset in ('sad', 'sad_step_2', 'sad_step_2_eval'):
        idx = [1]  # support affordance
    else:
        # An explicit raise survives ``python -O``; a bare assert would not.
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    njoints = datasets.__dict__[args.dataset].njoints
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)
    model = torch.nn.DataParallel(model).to(device)

    # define loss function (criterion) and optimizer
    criterion = losses.BCELoss().to(device)

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise ValueError('Unknown solver: {}'.format(args.solver))

    # optionally resume from a checkpoint
    title = args.dataset + ' ' + args.arch
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # full resume: epoch counter, best score, weights and optimizer
            args.start_epoch = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'),
                            title=title, resume=True)
        else:
            # evaluation below then runs on the freshly initialised weights
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Val Acc'])

    print(' Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # create data loaders; args.dataset selects the dataset class
    train_dataset = datasets.__dict__[args.dataset](is_train=True, **vars(args))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_dataset = datasets.__dict__[args.dataset](is_train=False, **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # redraw training / test label
    if args.relabel:
        RELABEL = True
        if args.evaluate:
            print('\nRelabel val label')
            loss, acc = validate(val_loader, model, criterion, njoints,
                                 args.checkpoint, args.debug, args.flip)
            print("Val acc: %.3f" % (acc))
            return

    # evaluation only
    JUST_EVALUATE = False
    if args.evaluate:
        print('\nEvaluation only')
        JUST_EVALUATE = True
        loss, acc = validate(val_loader, model, criterion, njoints,
                             args.checkpoint, args.debug, args.flip)
        print("Val acc: %.3f" % (acc))
        return