def load_data():
    """Build the LSUN train/val DataLoaders.

    Returns:
        (trainloader, valloader): training loader with random augmentation,
        validation loader with deterministic preprocessing.

    Relies on module-level globals: IM_SIZE, DATADIR, BATCH_SIZE, and the
    `transforms` / `dataset` (torchvision) imports.
    """
    # ImageNet channel statistics, the conventional values for pretrained backbones.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomCrop(IM_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    # Fix: validation must be deterministic — the original applied RandomCrop
    # and RandomHorizontalFlip here too, which makes every evaluation pass see
    # different pixels and val metrics noisy/non-reproducible.
    val_transform = transforms.Compose([
        transforms.CenterCrop(IM_SIZE),
        transforms.ToTensor(),
        normalize
    ])
    trainset = dataset.LSUN(DATADIR, 'train', train_transform)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=BATCH_SIZE,
                                              shuffle=True,
                                              num_workers=2,
                                              pin_memory=True,
                                              drop_last=True)
    print("Train set size: " + str(len(trainset)))
    valset = dataset.LSUN(DATADIR, 'val', val_transform)
    # NOTE(review): drop_last=True on the val loader discards the final partial
    # batch from evaluation — kept as-is to preserve existing metric semantics.
    valloader = torch.utils.data.DataLoader(valset,
                                            batch_size=BATCH_SIZE,
                                            shuffle=False,
                                            num_workers=2,
                                            pin_memory=True,
                                            drop_last=True)
    print("Val set size: " + str(len(valset)))
    return trainloader, valloader
def train_seg(args):
    """Active-learning training loop for DRN segmentation.

    Runs 10 acquisition cycles; each cycle selects `images_per_cycle` new
    images to label (by predicted loss, discriminative AL, or ground-truth
    loss as an oracle), then retrains a fresh model from scratch on all
    labeled images so far and records the best val accuracy/mAP per cycle.
    """
    # Fixed seeds so image selection and training are reproducible across runs.
    rand_state = np.random.RandomState(1311)
    torch.manual_seed(1311)
    device = 'cuda' if (torch.cuda.is_available()) else 'cpu'
    # NOTE(review): 150 images/cycle over the 10 cycles below = 1500 labeled
    # images total, roughly half of the 2975-image training set. (An earlier
    # comment said "500 for 3 cycles"; the code uses 150 x 10.)
    images_per_cycle = 150
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size
    print(' '.join(sys.argv))
    for k, v in args.__dict__.items():
        print(k, ':', v)
    # Data loading code: normalization stats come from the dataset's info.json.
    data_dir = args.data_dir
    info = json.load(open(join(data_dir, 'info.json'), 'r'))
    normalize = data_transforms.Normalize(mean=info['mean'], std=info['std'])
    t = []
    if args.random_rotate > 0:
        t.append(data_transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(data_transforms.RandomScale(args.random_scale))
    t.extend([
        data_transforms.RandomCrop(crop_size),
        data_transforms.RandomHorizontalFlip(),
        data_transforms.ToTensor(),
        normalize
    ])
    # Augmented view used for training; un-augmented view used for scoring
    # candidate images so acquisition scores are not perturbed by augmentation.
    dataset = SegList(data_dir, 'train', data_transforms.Compose(t),
                      list_dir=args.list_dir)
    training_dataset_no_augmentation = SegList(
        data_dir, 'train',
        data_transforms.Compose([data_transforms.ToTensor(), normalize]),
        list_dir=args.list_dir)
    # Acquisition state shared across cycles: indices into `dataset`.
    unlabeled_idx = list(range(len(dataset)))
    labeled_idx = []
    validation_accuracies = list()
    validation_mAPs = list()
    progress = tqdm.tqdm(range(10))
    for cycle in progress:
        # A fresh model is trained from scratch each cycle (standard AL protocol).
        single_model = DRNSeg(args.arch, args.classes, None, pretrained=True)
        if args.pretrained:
            single_model.load_state_dict(torch.load(args.pretrained))
        # Wrap our model in Active Learning Model.
        if args.use_loss_prediction_al:
            single_model = ActiveLearning(single_model,
                                          global_avg_pool_size=6,
                                          fc_width=256)
        elif args.use_discriminative_al:
            single_model = DiscriminativeActiveLearning(single_model)
        # Capture optim params before DataParallel wrapping hides the module.
        optim_parameters = single_model.optim_parameters()
        model = torch.nn.DataParallel(single_model).cuda()
        # Don't apply a 'mean' reduction, we need the whole loss vector.
        criterion = nn.NLLLoss(ignore_index=255, reduction='none')
        criterion.cuda()
        if args.choose_images_with_highest_loss:
            # Choosing images based on the ground truth labels.
            # We want to check if predicting loss with 100% accuracy would
            # result to a good active learning algorithm (oracle upper bound).
            new_indices, entropies = choose_new_labeled_indices_using_gt(
                model, cycle, rand_state, unlabeled_idx,
                training_dataset_no_augmentation, device, criterion,
                images_per_cycle)
        else:
            new_indices, entropies = choose_new_labeled_indices(
                model, training_dataset_no_augmentation, cycle, rand_state,
                labeled_idx, unlabeled_idx, device, images_per_cycle,
                args.use_loss_prediction_al, args.use_discriminative_al,
                input_pickle_file=None)
        labeled_idx.extend(new_indices)
        print("Running on {} labeled images.".format(len(labeled_idx)))
        if args.output_superannotate_csv_file is not None:
            # Write image paths to csv file which can be uploaded to annotate.online.
            write_entropies_csv(training_dataset_no_augmentation, new_indices,
                                entropies, args.output_superannotate_csv_file)
        # Train only on the currently-labeled subset.
        train_loader = torch.utils.data.DataLoader(
            data.Subset(dataset, labeled_idx),
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True)
        # NOTE(review): val uses RandomCrop (not a deterministic crop), so val
        # metrics have augmentation noise — confirm this is intended.
        val_loader = torch.utils.data.DataLoader(
            SegList(data_dir, 'val',
                    data_transforms.Compose([
                        data_transforms.RandomCrop(crop_size),
                        data_transforms.ToTensor(),
                        normalize,
                    ]),
                    list_dir=args.list_dir),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True)
        # define loss function (criterion) and optimizer.
        optimizer = torch.optim.SGD(optim_parameters,
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        cudnn.benchmark = True
        # Best metrics are reset per cycle; the per-cycle best is appended below.
        best_prec1 = 0
        best_mAP = 0
        start_epoch = 0
        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))
        if args.evaluate:
            # Evaluation-only mode: run one validation pass and exit.
            validate(val_loader, model, criterion, eval_score=accuracy,
                     num_classes=args.classes,
                     use_loss_prediction_al=args.use_loss_prediction_al)
            return
        progress_epoch = tqdm.tqdm(range(start_epoch, args.epochs))
        for epoch in progress_epoch:
            lr = adjust_learning_rate(args, optimizer, epoch)
            logger.info('Cycle {0} Epoch: [{1}]\tlr {2:.06f}'.format(
                cycle, epoch, lr))
            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch,
                  eval_score=accuracy,
                  use_loss_prediction_al=args.use_loss_prediction_al,
                  active_learning_lamda=args.lamda)
            # evaluate on validation set
            prec1, mAP1 = validate(
                val_loader, model, criterion, eval_score=accuracy,
                num_classes=args.classes,
                use_loss_prediction_al=args.use_loss_prediction_al)
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            best_mAP = max(mAP1, best_mAP)
            checkpoint_path = os.path.join(args.save_path,
                                           'checkpoint_latest.pth.tar')
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'best_mAP': best_mAP,
                },
                is_best,
                filename=checkpoint_path)
            if (epoch + 1) % args.save_iter == 0:
                # Keep a periodic history snapshot alongside the latest checkpoint.
                history_path = os.path.join(
                    args.save_path,
                    'checkpoint_{:03d}.pth.tar'.format(epoch + 1))
                shutil.copyfile(checkpoint_path, history_path)
        validation_accuracies.append(best_prec1)
        validation_mAPs.append(best_mAP)
        print("{} accuracies: {} mAPs {}".format(
            "Active Learning" if args.use_loss_prediction_al else "Random",
            str(validation_accuracies), str(validation_mAPs)))
# t.append(transforms.RandomScale(0)) normalize = transforms.Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.]) # t.extend([transforms.RandomCrop((768, 320), 4), # transforms.RandomHorizontalFlip(), # # transforms.ToNumpy(1/255.0), # # transforms.RandomGammaImg((0.7,1.5)), # # transforms.RandomBrightnessImg(0.2), # # transforms.RandomContrastImg((0.8, 1.2)), # # transforms.RandomGaussianNoiseImg(0.02), # # transforms.ToNumpy(255.0), # transforms.ToTensor(convert_pix_range=False), # normalize]) t.extend([ transforms.RandomCrop((768, 320), 4), transforms.RandomHorizontalFlip(), # transforms.ToNumpy(1/255.0), # transforms.RandomGammaImg((0.7,1.5)), # transforms.RandomBrightnessImg(0.2), # transforms.RandomContrastImg((0.8, 1.2)), # transforms.RandomGaussianNoiseImg(0.02), # transforms.ToNumpy(255.0), transforms.ToTensor(convert_pix_range=False), normalize ]) # data_dir = '/home/hzjiang/workspace/Data/CityScapes' data_dir = '/home/hzjiang/workspace/Data/KITTI_Semantics' train_data = SegList(data_dir, 'train',
def train_seg(args):
    """Train a DLA-up segmentation model on the Argoverse tracking data.

    Builds the model from `args.arch`, trains for `args.epochs` epochs with
    SGD, validates each epoch with a running confusion matrix, and writes
    checkpoints (and periodic history snapshots) under `args.checkpoint_dir`.
    TensorBoard scalars (mIoU, loss) are logged per epoch.
    """
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size
    checkpoint_dir = args.checkpoint_dir
    # Log the exact command line and all hyperparameters for reproducibility.
    print(' '.join(sys.argv))
    for k, v in args.__dict__.items():
        print(k, ':', v)
    pretrained_base = args.pretrained_base  # NOTE(review): read but unused here.
    # Look up the architecture constructor by name in the dla_up module.
    single_model = dla_up.__dict__.get(args.arch)(classes=args.classes,
                                                  down_ratio=args.down)
    single_model = convert_model(single_model)
    model = torch.nn.DataParallel(single_model).cuda()
    print('model_created')
    # ignore_index=-1 marks unlabeled pixels; an optional class weight boosts
    # the edge class when args.edge_weight > 0.
    if args.edge_weight > 0:
        weight = torch.from_numpy(
            np.array([1, args.edge_weight], dtype=np.float32))
        criterion = nn.NLLLoss2d(ignore_index=-1, weight=weight)
    else:
        criterion = nn.NLLLoss2d(ignore_index=-1)
    criterion.cuda()
    # Training augmentation pipeline; validation only crops to a fixed size.
    t = []
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.append(transforms.RandomCrop(crop_size))
    if args.random_color:
        t.append(transforms.RandomJitter(0.4, 0.4, 0.4))
    t.extend([transforms.RandomHorizontalFlip()])
    t_val = []
    t_val.append(transforms.RandomCrop(crop_size))
    dir_img = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/train/image_02/'
    dir_mask = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/train/' + args.target + '/'
    my_train = BasicDataset(dir_img,
                            dir_mask,
                            transforms.Compose(t),
                            is_train=True)
    val_dir_img = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/val/image_02/'
    val_dir_mask = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/val/' + args.target + '/'
    my_val = BasicDataset(val_dir_img,
                          val_dir_mask,
                          transforms.Compose(t_val),
                          is_train=True)
    train_loader = torch.utils.data.DataLoader(my_train,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(my_val,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers,
                                             pin_memory=True)
    print("loader created")
    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    lr_scheduler = None  # LR is adjusted manually via adjust_learning_rate.
    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    confusion_labels = np.arange(0, 5)
    val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                 ignore_label=-1)
    if args.evaluate:
        # Evaluation-only mode uses a binary (2-label), reduced confusion matrix.
        confusion_labels = np.arange(0, 2)
        val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                      ignore_label=-1,
                                                      reduce=True)
        validate(val_loader, model, criterion,
                 confusion_matrix=val_confusion_matrix)
        return
    writer = SummaryWriter(comment=args.log)
    for epoch in range(start_epoch, args.epochs):
        # Fresh confusion matrices each epoch so stats don't accumulate across epochs.
        train_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                        ignore_label=-1)
        lr = adjust_learning_rate(args, optimizer, epoch)
        print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, lr_scheduler,
              confusion_matrix=train_confusion_matrix, writer=writer)
        # Save immediately after training in case validation crashes.
        checkpoint_path = os.path.join(checkpoint_dir,
                                       'checkpoint_{}.pth.tar'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict()
            },
            is_best=False,
            filename=checkpoint_path)
        # evaluate on validation set
        val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                      ignore_label=-1)
        prec1, loss_val = validate(val_loader, model, criterion,
                                   confusion_matrix=val_confusion_matrix)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        writer.add_scalar('mIoU/epoch', prec1, epoch + 1)
        writer.add_scalar('loss/epoch', loss_val, epoch + 1)
        checkpoint_path = os.path.join(checkpoint_dir,
                                       'checkpoint_{}.pth.tar'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path)
        if (epoch + 1) % args.save_freq == 0:
            # Fix: history snapshots previously went to the current working
            # directory; keep them with the other checkpoints in checkpoint_dir.
            history_path = os.path.join(
                checkpoint_dir, 'checkpoint_{:03d}.pth.tar'.format(epoch + 1))
            shutil.copyfile(checkpoint_path, history_path)
    writer.close()
def train_seg(args):
    """Train a DLA-up segmentation model on a SegList dataset.

    Standard supervised loop: build model from `args.arch`, augment + train
    with SGD, validate each epoch, keep 'checkpoint_latest.pth.tar' plus a
    best-model copy (via save_checkpoint's is_best flag) and periodic
    history snapshots every `args.save_freq` epochs.
    """
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size
    # Log the exact command line and all hyperparameters for reproducibility.
    print(' '.join(sys.argv))
    for k, v in args.__dict__.items():
        print(k, ':', v)
    pretrained_base = args.pretrained_base
    # Look up the architecture constructor by name in the dla_up module.
    single_model = dla_up.__dict__.get(args.arch)(args.classes,
                                                  pretrained_base,
                                                  down_ratio=args.down)
    model = torch.nn.DataParallel(single_model).cuda()
    # ignore_index=255 marks unlabeled pixels; optional 2-class weight boosts
    # the edge class when args.edge_weight > 0.
    if args.edge_weight > 0:
        weight = torch.from_numpy(
            np.array([1, args.edge_weight], dtype=np.float32))
        criterion = nn.NLLLoss2d(ignore_index=255, weight=weight)
    else:
        criterion = nn.NLLLoss2d(ignore_index=255)
    criterion.cuda()
    # Normalization stats come from the dataset's own info file.
    data_dir = args.data_dir
    info = dataset.load_dataset_info(data_dir)
    normalize = transforms.Normalize(mean=info.mean, std=info.std)
    # Training augmentation pipeline (rotate/scale/crop/jitter/flip).
    t = []
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.append(transforms.RandomCrop(crop_size))
    if args.random_color:
        t.append(transforms.RandomJitter(0.4, 0.4, 0.4))
    t.extend(
        [transforms.RandomHorizontalFlip(),
         transforms.ToTensor(), normalize])
    # binary=True collapses the task to 2 classes when args.classes == 2.
    train_loader = torch.utils.data.DataLoader(SegList(
        data_dir, 'train', transforms.Compose(t),
        binary=(args.classes == 2)),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)
    # NOTE(review): val uses RandomCrop (not a deterministic crop), so val
    # metrics carry augmentation noise — confirm this is intended.
    val_loader = torch.utils.data.DataLoader(
        SegList(
            data_dir, 'val',
            transforms.Compose([
                transforms.RandomCrop(crop_size),
                # transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
            binary=(args.classes == 2)),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)
    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.evaluate:
        # Evaluation-only mode: run one validation pass and exit.
        validate(val_loader, model, criterion, eval_score=accuracy)
        return
    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer, epoch)
        print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch,
              eval_score=accuracy)
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, eval_score=accuracy)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        # Checkpoints are written to the current working directory.
        checkpoint_path = 'checkpoint_latest.pth.tar'
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path)
        if (epoch + 1) % args.save_freq == 0:
            # Periodic history snapshot of the latest checkpoint.
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)
def get_loader(args, split, out_name=False, customized_task_set=None):
    """Returns data loader depending on dataset and split.

    Args:
        args: namespace with dataset name, paths, batch sizes, worker count,
            and (optionally) a `task_set` attribute for multitask loaders.
        split: one of 'train', 'val', 'adv_val' ('adv_val' uses batch size 1
            for voc/coco/cityscape adversarial evaluation).
        out_name: forwarded to val loaders so samples carry their file names.
        customized_task_set: overrides args.task_set when given (taskonomy).

    Returns:
        A torch DataLoader for the requested dataset/split, or None if the
        dataset/split combination is not recognized.
    """
    dataset = args.dataset
    loader = None
    # Resolve the effective task set (only consumed by the taskonomy branch).
    if customized_task_set is None:
        task_set = args.task_set
    else:
        task_set = customized_task_set
    if dataset == 'taskonomy':
        print('using taskonomy')
        if split == 'train':
            # NOTE(review): shuffle=False on the *training* loader looks
            # unusual — confirm this is intentional for taskonomy.
            loader = torch.utils.data.DataLoader(
                TaskonomyLoader(root=args.data_dir,
                                is_training=True,
                                threshold=1200,
                                task_set=task_set,
                                model_whitelist=None,
                                model_limit=30,
                                output_size=None),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
        if split == 'val':
            loader = torch.utils.data.DataLoader(
                TaskonomyLoader(root=args.data_dir,
                                is_training=False,
                                threshold=1200,
                                task_set=task_set,
                                model_whitelist=None,
                                model_limit=30,
                                output_size=None),
                batch_size=args.test_batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
        if split == 'adv_val':
            loader = torch.utils.data.DataLoader(
                TaskonomyLoader(root=args.data_dir,
                                is_training=False,
                                threshold=1200,
                                task_set=task_set,
                                model_whitelist=None,
                                model_limit=30,
                                output_size=None),
                batch_size=args.test_batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
    elif dataset == 'voc':
        if split == 'train':
            loader = torch.utils.data.DataLoader(
                VOCSegmentation(args=args,
                                base_dir=args.data_dir,
                                split='train'),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
        elif split == 'val':
            loader = torch.utils.data.DataLoader(
                VOCSegmentation(args=args,
                                base_dir=args.data_dir,
                                split='val',
                                out_name=out_name),
                batch_size=args.test_batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
        elif split == 'adv_val':
            # Adversarial evaluation processes one image at a time.
            loader = torch.utils.data.DataLoader(
                VOCSegmentation(args=args,
                                base_dir=args.data_dir,
                                split='val',
                                out_name=out_name),
                batch_size=1,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
    elif dataset == 'coco':
        if split == 'train':
            loader = torch.utils.data.DataLoader(
                COCOSegmentation(args=args,
                                 base_dir=args.data_dir,
                                 split='train'),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
        elif split == 'val':
            loader = torch.utils.data.DataLoader(
                COCOSegmentation(args=args,
                                 base_dir=args.data_dir,
                                 split='val',
                                 out_name=out_name),
                batch_size=args.test_batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
        elif split == 'adv_val':
            loader = torch.utils.data.DataLoader(
                COCOSegmentation(args=args,
                                 base_dir=args.data_dir,
                                 split='val',
                                 out_name=out_name),
                batch_size=1,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
    elif dataset == 'cityscape':
        data_dir = args.data_dir
        info = json.load(open(join(data_dir, 'info.json'), 'r'))
        normalize = transforms.Normalize(mean=info['mean'], std=info['std'])
        t = []
        if args.random_rotate > 0:
            t.append(transforms.RandomRotate(args.random_rotate))
        if args.random_scale > 0:
            t.append(transforms.RandomScale(args.random_scale))
        t.extend([
            transforms.RandomCrop(args.crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ])
        # A task_set attribute selects the multitask (seg + depth) loaders.
        task_set_present = hasattr(args, 'task_set')
        if split == 'train':
            if task_set_present:
                print(
                    "\nCAUTION: THE DATALOADER IS FOR MULTITASK ON CITYSCAPE\n"
                )
                loader = torch.utils.data.DataLoader(
                    SegDepthList(data_dir,
                                 'train',
                                 transforms.Compose(t),
                                 list_dir=args.list_dir),
                    batch_size=args.batch_size,
                    shuffle=True,
                    num_workers=args.workers,
                    pin_memory=True,
                    drop_last=True)
            else:
                loader = torch.utils.data.DataLoader(
                    SegList(data_dir,
                            'train',
                            transforms.Compose(t),
                            list_dir=args.list_dir),
                    batch_size=args.batch_size,
                    shuffle=True,
                    num_workers=args.workers,
                    pin_memory=True,
                    drop_last=True)
        elif split == 'val':
            # Fix: guard with task_set_present before touching args.task_set —
            # the original read args.task_set unconditionally here and raised
            # AttributeError for args objects without that attribute, while the
            # 'train' and 'adv_val' branches already guarded with hasattr.
            if task_set_present and args.task_set != []:
                print(
                    "\nCAUTION: THE DATALOADER IS FOR MULTITASK ON CITYSCAPE\n"
                )
                loader = torch.utils.data.DataLoader(
                    SegDepthList(data_dir,
                                 'val',
                                 transforms.Compose([
                                     transforms.ToTensor(),
                                     normalize,
                                 ]),
                                 list_dir=args.list_dir,
                                 out_name=out_name),
                    batch_size=args.test_batch_size,
                    shuffle=False,
                    num_workers=args.workers,
                    pin_memory=True,
                    drop_last=True)
            else:
                print("city test eval!")
                loader = torch.utils.data.DataLoader(
                    SegList(data_dir,
                            'val',
                            transforms.Compose([
                                transforms.ToTensor(),
                                normalize,
                            ]),
                            list_dir=args.list_dir,
                            out_name=out_name),
                    batch_size=args.test_batch_size,
                    shuffle=False,
                    num_workers=args.workers,
                    pin_memory=True,
                    drop_last=True)
        elif split == 'adv_val':
            # has batch size 1
            if task_set_present:
                print(
                    "\nCAUTION: THE DATALOADER IS FOR MULTITASK ON CITYSCAPE\n"
                )
                loader = torch.utils.data.DataLoader(
                    SegDepthList(data_dir,
                                 'val',
                                 transforms.Compose([
                                     transforms.ToTensor(),
                                     normalize,
                                 ]),
                                 list_dir=args.list_dir,
                                 out_name=out_name),
                    batch_size=1,
                    shuffle=False,
                    num_workers=args.workers,
                    pin_memory=True,
                    drop_last=True)
            else:
                loader = torch.utils.data.DataLoader(
                    SegList(data_dir,
                            'val',
                            transforms.Compose([
                                transforms.ToTensor(),
                                normalize,
                            ]),
                            list_dir=args.list_dir,
                            out_name=out_name),
                    batch_size=1,
                    shuffle=False,
                    num_workers=args.workers,
                    pin_memory=True,
                    drop_last=True)
    return loader
self.image_list = [line.strip() for line in open(image_path, 'r')] if exists(label_path): self.label_list = [line.strip() for line in open(label_path, 'r')] assert len(self.image_list) == len(self.label_list) if __name__ == "__main__": #Testing the dataloader data_dir = "/home/amogh/data/datasets/drn_data/DRN-move/cityscape_dataset/" info = json.load(open(join(data_dir, 'info.json'), 'r')) normalize = transforms.Normalize(mean=info['mean'], std=info['std']) t = [] # t.append(transforms.RandomRotate(0)) # t.append(transforms.RandomScale(0)) t.extend([ transforms.RandomCrop(896), # transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize ]) # loader = SegDepthList(data_dir="/home/amogh/data/datasets/drn_data/DRN-move/cityscape_dataset/", loader = torch.utils.data.DataLoader(SegDepthList(data_dir, 'train', transforms.Compose(t), list_dir=None), batch_size=1, shuffle=False, num_workers=1, pin_memory=True, drop_last=True)
def train_seg(args):
    """Train a DLA-up segmentation model on COCO instance masks.

    Builds the model from `args.arch`, trains with SGD + cosine-annealing LR,
    validates each epoch (accuracy/loss/recall logged to TensorBoard), and
    writes per-epoch checkpoints plus periodic history snapshots under
    `args.checkpoint_dir`.
    """
    writer = SummaryWriter(comment=args.log)
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size
    checkpoint_dir = args.checkpoint_dir
    # Log the exact command line and all hyperparameters for reproducibility.
    print(' '.join(sys.argv))
    for k, v in args.__dict__.items():
        print(k, ':', v)
    pretrained_base = args.pretrained_base  # NOTE(review): read but unused here.
    # Look up the architecture constructor by name in the dla_up module.
    single_model = dla_up.__dict__.get(args.arch)(classes=args.classes,
                                                  down_ratio=args.down)
    model = torch.nn.DataParallel(single_model).cuda()
    print('model_created')
    # ignore_index=255 marks unlabeled pixels. When args.bg_weight > 0, class 0
    # (background) gets a custom weight; otherwise no class weighting.
    # (Original duplicated the NLLLoss2d construction in both branches.)
    weight = None
    if args.bg_weight > 0:
        weight_array = np.ones(args.classes, dtype=np.float32)
        weight_array[0] = args.bg_weight
        weight = torch.from_numpy(weight_array)
    criterion = nn.NLLLoss2d(ignore_index=255, weight=weight)
    criterion.cuda()
    # Training augmentation pipeline; validation only crops to a fixed size.
    t = []
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.append(transforms.RandomCrop(crop_size))
    if args.random_color:
        t.append(transforms.RandomJitter(0.4, 0.4, 0.4))
    t.extend([transforms.RandomHorizontalFlip()])
    t_val = []
    t_val.append(transforms.RandomCrop(crop_size))
    train_json = '/shared/xudongliu/COCO/annotation2017/annotations/instances_train2017.json'
    train_root = '/shared/xudongliu/COCO/train2017/train2017'
    my_train = COCOSeg(train_root, train_json, transforms.Compose(t),
                       is_train=True)
    val_json = '/shared/xudongliu/COCO/annotation2017/annotations/instances_val2017.json'
    val_root = '/shared/xudongliu/COCO/2017val/val2017'
    my_val = COCOSeg(val_root, val_json, transforms.Compose(t_val),
                     is_train=True)
    train_loader = torch.utils.data.DataLoader(my_train,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)
    # NOTE(review): val batch size is hard-coded to 20 (was marked TODO);
    # kept as-is to preserve existing evaluation behavior — confirm intent.
    val_loader = torch.utils.data.DataLoader(my_val,
                                             batch_size=20,
                                             shuffle=False,
                                             num_workers=num_workers,
                                             pin_memory=True)
    print("loader created")
    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                              T_max=32)
    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.evaluate:
        # Evaluation-only mode: run one validation pass and exit.
        validate(val_loader, model, criterion, eval_score=accuracy)
        return
    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer, epoch)
        print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, lr_scheduler,
              eval_score=accuracy, writer=writer)
        # Save immediately after training in case validation crashes.
        checkpoint_path = os.path.join(checkpoint_dir,
                                       'checkpoint_{}.pth.tar'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict()
            },
            is_best=False,
            filename=checkpoint_path)
        # evaluate on validation set
        prec1, loss_val, recall_val = validate(val_loader, model, criterion,
                                               eval_score=accuracy)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        writer.add_scalar('accuracy/epoch', prec1, epoch + 1)
        writer.add_scalar('loss/epoch', loss_val, epoch + 1)
        writer.add_scalar('recall/epoch', recall_val, epoch + 1)
        checkpoint_path = os.path.join(checkpoint_dir,
                                       'checkpoint_{}.pth.tar'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path)
        if (epoch + 1) % args.save_freq == 0:
            # Fix: history snapshots previously went to the current working
            # directory; keep them with the other checkpoints in checkpoint_dir.
            history_path = os.path.join(
                checkpoint_dir, 'checkpoint_{:03d}.pth.tar'.format(epoch + 1))
            shutil.copyfile(checkpoint_path, history_path)
    writer.close()
def train_cnn(args):
    """Train the QGCNN quality-assessment model.

    Configuration comes from the module-level `cfg` dict (crop size, feature
    losses and weights, data/list dirs) plus `args` for optimizer settings and
    run control. Checkpoints go to `<out_dir>/weights/`; best model tracked by
    validation score.
    """
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = cfg['CROP_SIZE']
    for k, v in args.__dict__.items():
        print(k, ':', v)
    single_model = QGCNN()
    model = torch.nn.DataParallel(single_model)
    # Unpack optional perceptual-feature losses: cfg['FEATS'] is a list of
    # single-entry dicts {feat_name: weight}.
    if cfg['FEATS']:
        feat_names, weights = zip(*(tuple(*f.items()) for f in cfg['FEATS']))
    else:
        feat_names, weights = None, None
    criterion = ComLoss(cfg['IQA_MODEL'],
                        weights,
                        feat_names,
                        patch_size=cfg['PATCH_SIZE'],
                        pixel_criterion=cfg['CRITERION'])
    criterion.cuda()
    # Data loading
    data_dir = cfg['DATA_DIR']
    list_dir = cfg['LIST_DIR']
    t = [
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ]
    # Note that the cropsize could have a significant influence,
    # i.e., with a small cropsize the model would get overfitted
    # easily thus hard to train
    train_loader = torch.utils.data.DataLoader(DataList(
        data_dir, 'train', transforms.Compose(t), list_dir=list_dir),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    # The cropsize of the validation set dramatically affects the
    # evaluation accuracy, which means the quality of the whole
    # image might be very different from that of its cropped patches.
    #
    # Try setting batch_size = 1 and no crop (disable RandomCrop)
    # to improve the effect of early stopping.
    # NOTE(review): val_loader is the raw DataList, not a DataLoader —
    # presumably validate() iterates samples directly; confirm before changing.
    val_loader = DataList(data_dir,
                          'val',
                          transforms.Compose([transforms.ToTensor()]),
                          list_dir=list_dir)
    optimizer = torch.optim.Adam(single_model.parameters(),
                                 lr=args.lr,
                                 betas=(0.9, 0.99),
                                 weight_decay=args.weight_decay)
    cudnn.benchmark = True
    # `out_dir` is a module-level global (defined outside this function).
    weight_dir = join(out_dir, 'weights/')
    if not exists(weight_dir):
        os.mkdir(weight_dir)
    best_prec = 0
    start_epoch = 0
    # Optionally resume from a checkpoint
    # NOTE(review): resume restores epoch and weights but not `best_prec`,
    # so a resumed run may re-mark a worse model as "best" — verify intent.
    if args.resume:
        if os.path.isfile(args.resume):
            logger_s.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            logger_s.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger_f.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger_f.warning("=> no checkpoint found at '{}'".format(
                args.resume))
    if args.evaluate:
        # Evaluation-only mode: run one validation pass and exit.
        validate(val_loader, model.cuda(), criterion, eval_score=accuracy)
        return
    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer, epoch)
        # Decay the feature-loss weights every 100 epochs.
        if criterion.weights is not None and (epoch + 1) % 100 == 0:
            criterion.weights /= 10.0
        logger_s.info('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # train for one epoch
        train(train_loader,
              model.cuda(),
              criterion,
              optimizer,
              epoch,
              eval_score=accuracy)
        # Evaluate on validation set
        prec = validate(val_loader, model.cuda(), criterion,
                        eval_score=accuracy)
        is_best = prec > best_prec
        best_prec = max(prec, best_prec)
        logger_s.info('current best {:.6f}'.format(best_prec))
        checkpoint_path = join(weight_dir, 'checkpoint_latest.pkl')
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec': best_prec,
            },
            is_best,
            filename=checkpoint_path)
        if (epoch + 1) % args.store_interval == 0:
            # Periodic history snapshot of the latest checkpoint.
            history_path = join(weight_dir,
                                'checkpoint_{:03d}.pkl'.format(epoch + 1))
            shutil.copyfile(checkpoint_path, history_path)
def train_seg(args):
    """Train a DRN segmentation model, mirroring checkpoints to S3.

    Standard supervised loop: build DRNSeg from `args.arch`, augment + train
    with SGD, validate each epoch, save 'checkpoint_latest.pth.tar' (and a
    best copy via save_checkpoint), keep a history snapshot every 10 epochs,
    and best-effort upload checkpoints to S3.
    """
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size
    # Log the exact command line and all hyperparameters for reproducibility.
    print(' '.join(sys.argv))
    for k, v in args.__dict__.items():
        print(k, ':', v)
    single_model = DRNSeg(args.arch, args.classes, None, pretrained=True)
    if args.pretrained:
        single_model.load_state_dict(torch.load(args.pretrained))
    model = torch.nn.DataParallel(single_model).cuda()
    # ignore_index=255 marks unlabeled pixels.
    criterion = nn.NLLLoss2d(ignore_index=255)
    criterion.cuda()
    # Data loading code: normalization stats come from the dataset's info.json.
    data_dir = args.data_dir
    info = json.load(open(join(data_dir, 'info.json'), 'r'))
    normalize = transforms.Normalize(mean=info['mean'], std=info['std'])
    t = []
    if args.downsample:
        t.append(transforms.Scale(0.5))
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.extend([
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])
    train_loader = torch.utils.data.DataLoader(SegList(
        data_dir, 'train', transforms.Compose(t)),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(SegList(
        data_dir, 'val',
        transforms.Compose([
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             drop_last=True)
    # define loss function (criterion) and optimizer
    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.evaluate:
        # Evaluation-only mode: run one validation pass and exit.
        validate(val_loader, model, criterion, eval_score=accuracy)
        return
    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer, epoch)
        logger.info('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch,
              eval_score=accuracy)
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, eval_score=accuracy)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        checkpoint_path = 'checkpoint_latest.pth.tar'
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path,
            prefix=args.arch)
        if (epoch + 1) % 10 == 0:
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)
            # save historical data to s3 — best-effort, like the latest upload
            # below; a transient S3 failure should not kill the training run.
            # (Previously unguarded.)
            try:
                upload_to_s3(history_path, prefix=args.arch)
            except Exception:
                logging.exception('failed to upload history checkpoint to s3')
        # save latest checkpoint to s3
        try:
            upload_to_s3(checkpoint_path, prefix=args.arch)
        except Exception:
            # Fix: was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt; logging.exception records the traceback.
            logging.exception('failed to upload latest checkpoint to s3')