def train_seg(args):
    rand_state = np.random.RandomState(1311)
    torch.manual_seed(1311)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # The training set has 2975 images in total; we label 150 new images per
    # cycle for 10 cycles, i.e. 1500 images (~1/2 of the total).
    images_per_cycle = 150
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size

    print(' '.join(sys.argv))
    for k, v in args.__dict__.items():
        print(k, ':', v)

    # Data loading code
    data_dir = args.data_dir
    info = json.load(open(join(data_dir, 'info.json'), 'r'))
    normalize = data_transforms.Normalize(mean=info['mean'], std=info['std'])
    t = []
    if args.random_rotate > 0:
        t.append(data_transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(data_transforms.RandomScale(args.random_scale))
    t.extend([
        data_transforms.RandomCrop(crop_size),
        data_transforms.RandomHorizontalFlip(),
        data_transforms.ToTensor(),
        normalize
    ])
    dataset = SegList(data_dir, 'train', data_transforms.Compose(t),
                      list_dir=args.list_dir)
    training_dataset_no_augmentation = SegList(
        data_dir, 'train',
        data_transforms.Compose([data_transforms.ToTensor(), normalize]),
        list_dir=args.list_dir)

    unlabeled_idx = list(range(len(dataset)))
    labeled_idx = []
    validation_accuracies = []
    validation_mAPs = []

    progress = tqdm.tqdm(range(10))
    for cycle in progress:
        single_model = DRNSeg(args.arch, args.classes, None, pretrained=True)
        if args.pretrained:
            single_model.load_state_dict(torch.load(args.pretrained))
        # Wrap our model in an active-learning model.
        if args.use_loss_prediction_al:
            single_model = ActiveLearning(single_model,
                                          global_avg_pool_size=6,
                                          fc_width=256)
        elif args.use_discriminative_al:
            single_model = DiscriminativeActiveLearning(single_model)
        optim_parameters = single_model.optim_parameters()
        model = torch.nn.DataParallel(single_model).cuda()

        # Don't apply a 'mean' reduction; we need the whole loss vector.
        criterion = nn.NLLLoss(ignore_index=255, reduction='none')
        criterion.cuda()

        if args.choose_images_with_highest_loss:
            # Choose images based on the ground-truth labels. We want to check
            # whether predicting the loss with 100% accuracy would result in a
            # good active-learning algorithm.
            new_indices, entropies = choose_new_labeled_indices_using_gt(
                model, cycle, rand_state, unlabeled_idx,
                training_dataset_no_augmentation, device, criterion,
                images_per_cycle)
        else:
            new_indices, entropies = choose_new_labeled_indices(
                model, training_dataset_no_augmentation, cycle, rand_state,
                labeled_idx, unlabeled_idx, device, images_per_cycle,
                args.use_loss_prediction_al, args.use_discriminative_al,
                input_pickle_file=None)
        labeled_idx.extend(new_indices)
        print("Running on {} labeled images.".format(len(labeled_idx)))

        if args.output_superannotate_csv_file is not None:
            # Write image paths to a CSV file which can be uploaded to
            # annotate.online.
            write_entropies_csv(training_dataset_no_augmentation, new_indices,
                                entropies, args.output_superannotate_csv_file)

        train_loader = torch.utils.data.DataLoader(
            data.Subset(dataset, labeled_idx),
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True)
        val_loader = torch.utils.data.DataLoader(
            SegList(data_dir, 'val',
                    data_transforms.Compose([
                        data_transforms.RandomCrop(crop_size),
                        data_transforms.ToTensor(),
                        normalize,
                    ]),
                    list_dir=args.list_dir),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True)

        # Define the optimizer (the criterion was defined above).
        optimizer = torch.optim.SGD(optim_parameters,
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        cudnn.benchmark = True
        best_prec1 = 0
        best_mAP = 0
        start_epoch = 0

        # Optionally resume from a checkpoint.
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        if args.evaluate:
            validate(val_loader, model, criterion, eval_score=accuracy,
                     num_classes=args.classes,
                     use_loss_prediction_al=args.use_loss_prediction_al)
            return

        progress_epoch = tqdm.tqdm(range(start_epoch, args.epochs))
        for epoch in progress_epoch:
            lr = adjust_learning_rate(args, optimizer, epoch)
            logger.info('Cycle {0} Epoch: [{1}]\tlr {2:.06f}'.format(
                cycle, epoch, lr))
            # Train for one epoch.
            train(train_loader, model, criterion, optimizer, epoch,
                  eval_score=accuracy,
                  use_loss_prediction_al=args.use_loss_prediction_al,
                  active_learning_lamda=args.lamda)
            # Evaluate on the validation set.
            prec1, mAP1 = validate(
                val_loader, model, criterion, eval_score=accuracy,
                num_classes=args.classes,
                use_loss_prediction_al=args.use_loss_prediction_al)
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            best_mAP = max(mAP1, best_mAP)
            checkpoint_path = os.path.join(args.save_path,
                                           'checkpoint_latest.pth.tar')
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'best_mAP': best_mAP,
                },
                is_best,
                filename=checkpoint_path)
            if (epoch + 1) % args.save_iter == 0:
                history_path = os.path.join(
                    args.save_path,
                    'checkpoint_{:03d}.pth.tar'.format(epoch + 1))
                shutil.copyfile(checkpoint_path, history_path)

        validation_accuracies.append(best_prec1)
        validation_mAPs.append(best_mAP)
        print("{} accuracies: {} mAPs {}".format(
            "Active Learning" if args.use_loss_prediction_al else "Random",
            str(validation_accuracies), str(validation_mAPs)))
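# choose_new_labeled_indices and choose_new_labeled_indices_using_gt are
# defined elsewhere in this repo. As a rough illustration of the
# uncertainty-based branch only, here is a minimal entropy-ranking sketch.
# The function name and signature are hypothetical, the dataset is assumed to
# yield (image, target) pairs, and the model is assumed to return per-pixel
# log-probabilities (the wrapped models above may instead return tuples).
import torch


def rank_unlabeled_by_entropy(model, dataset, unlabeled_idx, device, k):
    """Return the k unlabeled indices with the highest mean per-pixel
    prediction entropy, together with their entropy scores."""
    model.eval()
    scored = []
    with torch.no_grad():
        for idx in unlabeled_idx:
            image = dataset[idx][0].unsqueeze(0).to(device)
            log_probs = model(image)  # assumed shape (1, C, H, W)
            probs = log_probs.exp()
            # Mean over pixels of the per-pixel entropy over classes.
            entropy = -(probs * log_probs).sum(dim=1).mean().item()
            scored.append((entropy, idx))
    scored.sort(reverse=True)
    top = scored[:k]
    return [idx for _, idx in top], [score for score, _ in top]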
def train_seg(args):
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size
    checkpoint_dir = args.checkpoint_dir

    print(' '.join(sys.argv))
    for k, v in args.__dict__.items():
        print(k, ':', v)

    pretrained_base = args.pretrained_base
    single_model = dla_up.__dict__.get(args.arch)(classes=args.classes,
                                                  down_ratio=args.down)
    single_model = convert_model(single_model)
    model = torch.nn.DataParallel(single_model).cuda()
    print('model created')

    # This dataset encodes ignored pixels as -1 (other scripts here use 255).
    if args.edge_weight > 0:
        weight = torch.from_numpy(
            np.array([1, args.edge_weight], dtype=np.float32))
        criterion = nn.NLLLoss2d(ignore_index=-1, weight=weight)
    else:
        criterion = nn.NLLLoss2d(ignore_index=-1)
    criterion.cuda()

    t = []
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.append(transforms.RandomCrop(crop_size))  # TODO
    if args.random_color:
        t.append(transforms.RandomJitter(0.4, 0.4, 0.4))
    t.append(transforms.RandomHorizontalFlip())  # TODO
    t_val = [transforms.RandomCrop(crop_size)]

    dir_img = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/train/image_02/'
    dir_mask = ('/shared/xudongliu/data/argoverse-tracking/argo_track_all/train/'
                + args.target + '/')
    my_train = BasicDataset(dir_img, dir_mask, transforms.Compose(t),
                            is_train=True)
    val_dir_img = '/shared/xudongliu/data/argoverse-tracking/argo_track_all/val/image_02/'
    val_dir_mask = ('/shared/xudongliu/data/argoverse-tracking/argo_track_all/val/'
                    + args.target + '/')
    my_val = BasicDataset(val_dir_img, val_dir_mask, transforms.Compose(t_val),
                          is_train=True)
    train_loader = torch.utils.data.DataLoader(my_train,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(my_val,
                                             batch_size=batch_size,  # TODO: batch_size
                                             shuffle=False,
                                             num_workers=num_workers,
                                             pin_memory=True)
    print("loader created")

    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    lr_scheduler = None  # TODO
    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0

    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    confusion_labels = np.arange(0, 5)
    val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                  ignore_label=-1)
    if args.evaluate:
        confusion_labels = np.arange(0, 2)
        val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                      ignore_label=-1,
                                                      reduce=True)
        validate(val_loader, model, criterion,
                 confusion_matrix=val_confusion_matrix)
        return

    writer = SummaryWriter(comment=args.log)
    # TODO: test val
    # prec1 = validate(val_loader, model, criterion,
    #                  confusion_matrix=val_confusion_matrix)
    for epoch in range(start_epoch, args.epochs):
        train_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                        ignore_label=-1)
        lr = adjust_learning_rate(args, optimizer, epoch)
        print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # Train for one epoch.
        train(train_loader, model, criterion, optimizer, epoch, lr_scheduler,
              confusion_matrix=train_confusion_matrix, writer=writer)
        checkpoint_path = os.path.join(checkpoint_dir,
                                       'checkpoint_{}.pth.tar'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict()
            },
            is_best=False,
            filename=checkpoint_path)
        # Evaluate on the validation set.
        val_confusion_matrix = RunningConfusionMatrix(confusion_labels,
                                                      ignore_label=-1)
        prec1, loss_val = validate(val_loader, model, criterion,
                                   confusion_matrix=val_confusion_matrix)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        writer.add_scalar('mIoU/epoch', prec1, epoch + 1)
        writer.add_scalar('loss/epoch', loss_val, epoch + 1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path)
        if (epoch + 1) % args.save_freq == 0:
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)
    writer.close()
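# RunningConfusionMatrix is imported from elsewhere in the repo. A minimal
# sketch of the idea follows, with hypothetical method names (update,
# compute_miou): accumulate a C x C confusion matrix over batches, skip
# pixels whose label equals ignore_label, and derive per-class IoU from it.
import numpy as np


class RunningConfusionMatrixSketch:
    def __init__(self, labels, ignore_label=-1):
        self.labels = list(labels)
        self.ignore_label = ignore_label
        n = len(self.labels)
        self.matrix = np.zeros((n, n), dtype=np.int64)

    def update(self, target, prediction):
        """Accumulate counts; target and prediction are integer arrays of
        equal shape with values in range(len(labels)) or ignore_label."""
        target = target.ravel()
        prediction = prediction.ravel()
        keep = target != self.ignore_label
        t, p = target[keep], prediction[keep]
        n = len(self.labels)
        # Row = ground truth, column = prediction.
        self.matrix += np.bincount(t * n + p,
                                   minlength=n * n).reshape(n, n)

    def compute_miou(self):
        tp = np.diag(self.matrix)
        union = self.matrix.sum(0) + self.matrix.sum(1) - tp
        iou = tp / np.maximum(union, 1)
        return iou.mean()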
def train_seg(args):
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size

    print(' '.join(sys.argv))
    for k, v in args.__dict__.items():
        print(k, ':', v)

    pretrained_base = args.pretrained_base
    single_model = dla_up.__dict__.get(args.arch)(args.classes,
                                                  pretrained_base,
                                                  down_ratio=args.down)
    model = torch.nn.DataParallel(single_model).cuda()
    if args.edge_weight > 0:
        weight = torch.from_numpy(
            np.array([1, args.edge_weight], dtype=np.float32))
        criterion = nn.NLLLoss2d(ignore_index=255, weight=weight)
    else:
        criterion = nn.NLLLoss2d(ignore_index=255)
    criterion.cuda()

    data_dir = args.data_dir
    info = dataset.load_dataset_info(data_dir)
    normalize = transforms.Normalize(mean=info.mean, std=info.std)
    t = []
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.append(transforms.RandomCrop(crop_size))
    if args.random_color:
        t.append(transforms.RandomJitter(0.4, 0.4, 0.4))
    t.extend([transforms.RandomHorizontalFlip(),
              transforms.ToTensor(),
              normalize])
    train_loader = torch.utils.data.DataLoader(
        SegList(data_dir, 'train', transforms.Compose(t),
                binary=(args.classes == 2)),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        SegList(data_dir, 'val',
                transforms.Compose([
                    transforms.RandomCrop(crop_size),
                    # transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]),
                binary=(args.classes == 2)),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True)

    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0

    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        validate(val_loader, model, criterion, eval_score=accuracy)
        return

    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer, epoch)
        print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # Train for one epoch.
        train(train_loader, model, criterion, optimizer, epoch,
              eval_score=accuracy)
        # Evaluate on the validation set.
        prec1 = validate(val_loader, model, criterion, eval_score=accuracy)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        checkpoint_path = 'checkpoint_latest.pth.tar'
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path)
        if (epoch + 1) % args.save_freq == 0:
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)
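# adjust_learning_rate is shared by the training scripts in this file but
# defined elsewhere. DRN/DLA-style repos commonly use step or poly decay; the
# sketch below shows both, assuming args carries lr, lr_mode, step, and
# epochs. This is illustrative, not necessarily the repo's exact schedule.
def adjust_learning_rate_sketch(args, optimizer, epoch):
    if args.lr_mode == 'step':
        # Decay by 10x every args.step epochs.
        lr = args.lr * (0.1 ** (epoch // args.step))
    elif args.lr_mode == 'poly':
        # Polynomial decay toward zero over the full training run.
        lr = args.lr * (1 - epoch / args.epochs) ** 0.9
    else:
        raise ValueError('Unknown lr mode {}'.format(args.lr_mode))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr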
def get_loader(args, split, out_name=False, customized_task_set=None):
    """Return a data loader for the given dataset and split."""
    dataset = args.dataset
    loader = None
    if customized_task_set is None:
        task_set = args.task_set
    else:
        task_set = customized_task_set

    if dataset == 'taskonomy':
        print('using taskonomy')
        if split == 'train':
            # Note: the taskonomy training loader is not shuffled here.
            loader = torch.utils.data.DataLoader(
                TaskonomyLoader(root=args.data_dir,
                                is_training=True,
                                threshold=1200,
                                task_set=task_set,
                                model_whitelist=None,
                                model_limit=30,
                                output_size=None),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
        elif split in ('val', 'adv_val'):
            # The two evaluation splits use identical loaders for taskonomy.
            loader = torch.utils.data.DataLoader(
                TaskonomyLoader(root=args.data_dir,
                                is_training=False,
                                threshold=1200,
                                task_set=task_set,
                                model_whitelist=None,
                                model_limit=30,
                                output_size=None),
                batch_size=args.test_batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
    elif dataset == 'voc':
        if split == 'train':
            loader = torch.utils.data.DataLoader(
                VOCSegmentation(args=args, base_dir=args.data_dir,
                                split='train'),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
        elif split == 'val':
            loader = torch.utils.data.DataLoader(
                VOCSegmentation(args=args, base_dir=args.data_dir,
                                split='val', out_name=out_name),
                batch_size=args.test_batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
        elif split == 'adv_val':
            # Adversarial evaluation runs with batch size 1.
            loader = torch.utils.data.DataLoader(
                VOCSegmentation(args=args, base_dir=args.data_dir,
                                split='val', out_name=out_name),
                batch_size=1,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
    elif dataset == 'coco':
        if split == 'train':
            loader = torch.utils.data.DataLoader(
                COCOSegmentation(args=args, base_dir=args.data_dir,
                                 split='train'),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
        elif split == 'val':
            loader = torch.utils.data.DataLoader(
                COCOSegmentation(args=args, base_dir=args.data_dir,
                                 split='val', out_name=out_name),
                batch_size=args.test_batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
        elif split == 'adv_val':
            # Adversarial evaluation runs with batch size 1.
            loader = torch.utils.data.DataLoader(
                COCOSegmentation(args=args, base_dir=args.data_dir,
                                 split='val', out_name=out_name),
                batch_size=1,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
                drop_last=True)
    elif dataset == 'cityscape':
        data_dir = args.data_dir
        info = json.load(open(join(data_dir, 'info.json'), 'r'))
        normalize = transforms.Normalize(mean=info['mean'], std=info['std'])
        t = []
        if args.random_rotate > 0:
            t.append(transforms.RandomRotate(args.random_rotate))
        if args.random_scale > 0:
            t.append(transforms.RandomScale(args.random_scale))
        t.extend([
            transforms.RandomCrop(args.crop_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])
        task_set_present = hasattr(args, 'task_set')
        if split == 'train':
            if task_set_present:
                print("\nCAUTION: THE DATALOADER IS FOR MULTITASK ON CITYSCAPE\n")
                loader = torch.utils.data.DataLoader(
                    SegDepthList(data_dir, 'train', transforms.Compose(t),
                                 list_dir=args.list_dir),
                    batch_size=args.batch_size,
                    shuffle=True,
                    num_workers=args.workers,
                    pin_memory=True,
                    drop_last=True)
            else:
                loader = torch.utils.data.DataLoader(
                    SegList(data_dir, 'train', transforms.Compose(t),
                            list_dir=args.list_dir),
                    batch_size=args.batch_size,
                    shuffle=True,
                    num_workers=args.workers,
                    pin_memory=True,
                    drop_last=True)
        elif split == 'val':
            # Guard with task_set_present so a missing args.task_set does not
            # raise an AttributeError here.
            if task_set_present and args.task_set != []:
                print("\nCAUTION: THE DATALOADER IS FOR MULTITASK ON CITYSCAPE\n")
                loader = torch.utils.data.DataLoader(
                    SegDepthList(data_dir, 'val',
                                 transforms.Compose([
                                     transforms.ToTensor(),
                                     normalize,
                                 ]),
                                 list_dir=args.list_dir,
                                 out_name=out_name),
                    batch_size=args.test_batch_size,
                    shuffle=False,
                    num_workers=args.workers,
                    pin_memory=True,
                    drop_last=True)
            else:
                print("city test eval!")
                loader = torch.utils.data.DataLoader(
                    SegList(data_dir, 'val',
                            transforms.Compose([
                                transforms.ToTensor(),
                                normalize,
                            ]),
                            list_dir=args.list_dir,
                            out_name=out_name),
                    batch_size=args.test_batch_size,
                    shuffle=False,
                    num_workers=args.workers,
                    pin_memory=True,
                    drop_last=True)
        elif split == 'adv_val':
            # Adversarial evaluation runs with batch size 1.
            if task_set_present:
                print("\nCAUTION: THE DATALOADER IS FOR MULTITASK ON CITYSCAPE\n")
                loader = torch.utils.data.DataLoader(
                    SegDepthList(data_dir, 'val',
                                 transforms.Compose([
                                     transforms.ToTensor(),
                                     normalize,
                                 ]),
                                 list_dir=args.list_dir,
                                 out_name=out_name),
                    batch_size=1,
                    shuffle=False,
                    num_workers=args.workers,
                    pin_memory=True,
                    drop_last=True)
            else:
                loader = torch.utils.data.DataLoader(
                    SegList(data_dir, 'val',
                            transforms.Compose([
                                transforms.ToTensor(),
                                normalize,
                            ]),
                            list_dir=args.list_dir,
                            out_name=out_name),
                    batch_size=1,
                    shuffle=False,
                    num_workers=args.workers,
                    pin_memory=True,
                    drop_last=True)
    return loader
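# A minimal usage sketch for get_loader, assuming `args` was produced by the
# repo's argument parser with the fields referenced above (dataset, data_dir,
# batch_size, test_batch_size, workers, and the cityscape-specific options).
# The wrapper name build_loaders is hypothetical.
def build_loaders(args):
    train_loader = get_loader(args, 'train')
    val_loader = get_loader(args, 'val')
    # 'adv_val' returns a batch-size-1 loader for VOC/COCO/cityscape,
    # intended for per-image adversarial evaluation.
    adv_val_loader = get_loader(args, 'adv_val')
    return train_loader, val_loader, adv_val_loader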
def train_seg(args):
    writer = SummaryWriter(comment=args.log)
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size
    checkpoint_dir = args.checkpoint_dir

    print(' '.join(sys.argv))
    for k, v in args.__dict__.items():
        print(k, ':', v)

    pretrained_base = args.pretrained_base
    single_model = dla_up.__dict__.get(args.arch)(classes=args.classes,
                                                  down_ratio=args.down)
    model = torch.nn.DataParallel(single_model).cuda()
    print('model created')

    if args.bg_weight > 0:
        # Re-weight the background class (index 0) in the loss.
        weight_array = np.ones(args.classes, dtype=np.float32)
        weight_array[0] = args.bg_weight
        weight = torch.from_numpy(weight_array)
        criterion = nn.NLLLoss2d(ignore_index=255, weight=weight)
    else:
        criterion = nn.NLLLoss2d(ignore_index=255)
    criterion.cuda()

    t = []
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.append(transforms.RandomCrop(crop_size))  # TODO
    if args.random_color:
        t.append(transforms.RandomJitter(0.4, 0.4, 0.4))
    t.append(transforms.RandomHorizontalFlip())  # TODO
    t_val = [transforms.RandomCrop(crop_size)]

    train_json = '/shared/xudongliu/COCO/annotation2017/annotations/instances_train2017.json'
    train_root = '/shared/xudongliu/COCO/train2017/train2017'
    my_train = COCOSeg(train_root, train_json, transforms.Compose(t),
                       is_train=True)
    val_json = '/shared/xudongliu/COCO/annotation2017/annotations/instances_val2017.json'
    val_root = '/shared/xudongliu/COCO/2017val/val2017'
    my_val = COCOSeg(val_root, val_json, transforms.Compose(t_val),
                     is_train=True)
    train_loader = torch.utils.data.DataLoader(my_train,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(my_val,
                                             batch_size=20,  # TODO: batch_size
                                             shuffle=False,
                                             num_workers=num_workers,
                                             pin_memory=True)
    print("loader created")

    # TODO: try the Adam optimizer as an alternative.
    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                              T_max=32)  # TODO
    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0

    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        validate(val_loader, model, criterion, eval_score=accuracy)
        return

    # TODO: test val
    # prec1 = validate(val_loader, model, criterion, eval_score=accuracy)
    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer, epoch)
        print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # Train for one epoch.
        train(train_loader, model, criterion, optimizer, epoch, lr_scheduler,
              eval_score=accuracy, writer=writer)
        checkpoint_path = os.path.join(checkpoint_dir,
                                       'checkpoint_{}.pth.tar'.format(epoch))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict()
            },
            is_best=False,
            filename=checkpoint_path)
        # Evaluate on the validation set.
        prec1, loss_val, recall_val = validate(val_loader, model, criterion,
                                               eval_score=accuracy)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        writer.add_scalar('accuracy/epoch', prec1, epoch + 1)
        writer.add_scalar('loss/epoch', loss_val, epoch + 1)
        writer.add_scalar('recall/epoch', recall_val, epoch + 1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path)
        if (epoch + 1) % args.save_freq == 0:
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)
    writer.close()
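# save_checkpoint is shared across these scripts but defined elsewhere. The
# conventional implementation (as in the original PyTorch ImageNet example)
# writes the state dict to `filename` and, when is_best is set, copies it to
# a fixed "best" filename; the sketch below assumes that convention.
import shutil
import torch


def save_checkpoint_sketch(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')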
def train_seg(args):
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size

    print(' '.join(sys.argv))
    for k, v in args.__dict__.items():
        print(k, ':', v)

    single_model = DRNSeg(args.arch, args.classes, None, pretrained=True)
    if args.pretrained:
        single_model.load_state_dict(torch.load(args.pretrained))
    model = torch.nn.DataParallel(single_model).cuda()
    criterion = nn.NLLLoss2d(ignore_index=255)
    criterion.cuda()

    # Data loading code
    data_dir = args.data_dir
    info = json.load(open(join(data_dir, 'info.json'), 'r'))
    normalize = transforms.Normalize(mean=info['mean'], std=info['std'])
    t = []
    if args.downsample:
        t.append(transforms.Scale(0.5))
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.extend([
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])
    train_loader = torch.utils.data.DataLoader(
        SegList(data_dir, 'train', transforms.Compose(t)),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        SegList(data_dir, 'val',
                transforms.Compose([
                    transforms.RandomCrop(crop_size),
                    transforms.ToTensor(),
                    normalize,
                ])),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True)

    # Define the optimizer (the criterion was defined above).
    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0

    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        validate(val_loader, model, criterion, eval_score=accuracy)
        return

    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer, epoch)
        logger.info('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # Train for one epoch.
        train(train_loader, model, criterion, optimizer, epoch,
              eval_score=accuracy)
        # Evaluate on the validation set.
        prec1 = validate(val_loader, model, criterion, eval_score=accuracy)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        checkpoint_path = 'checkpoint_latest.pth.tar'
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            },
            is_best,
            filename=checkpoint_path,
            prefix=args.arch)
        if (epoch + 1) % 10 == 0:
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)
            # Save historical checkpoints to S3.
            upload_to_s3(history_path, prefix=args.arch)
        # Save the latest checkpoint to S3.
        try:
            upload_to_s3(checkpoint_path, prefix=args.arch)
        except Exception:
            logging.info('failed to upload latest checkpoint to s3')
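# upload_to_s3 is a repo helper that is not shown here. A minimal boto3-based
# sketch of what it might look like follows; the bucket name and key layout
# are assumptions, not the repo's actual configuration.
import os

import boto3


def upload_to_s3_sketch(path, prefix='', bucket='my-checkpoint-bucket'):
    """Upload a local file to s3://<bucket>/<prefix>/<basename>."""
    name = os.path.basename(path)
    key = '{}/{}'.format(prefix, name) if prefix else name
    s3 = boto3.client('s3')
    s3.upload_file(path, bucket, key)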