def test_model(args):
    # create model
    model = dla.__dict__[args.arch](pretrained=args.pretrained,
                                    pool_size=args.crop_size // 32)
    model = torch.nn.DataParallel(model)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {} prec {:.03f})"
                  .format(args.resume, checkpoint['epoch'], best_prec1))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    data = dataset.get_data(args.data_name)
    if data is None:
        data = dataset.load_dataset_info(args.data, data_name=args.data_name)
    if data is None:
        raise ValueError('{} is not pre-defined in dataset.py and info.json '
                         'does not exist in {}'.format(args.data_name,
                                                       args.data))

    # Data loading code
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=data.mean, std=data.std)
    if args.crop_10:
        # 10-crop evaluation: cropping is deferred to validate_10
        t = transforms.Compose([
            transforms.Resize(args.scale_size),
            transforms.ToTensor(),
            normalize])
    else:
        t = transforms.Compose([
            transforms.Resize(args.scale_size),
            transforms.CenterCrop(args.crop_size),
            transforms.ToTensor(),
            normalize])
    val_loader = torch.utils.data.DataLoader(
        ImageFolder(valdir, t, out_name=args.crop_10),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    if args.cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    if args.crop_10:
        validate_10(args, val_loader, model,
                    '{}_i_{}_c_10.txt'.format(args.arch, args.start_epoch))
    else:
        validate(args, val_loader, model, criterion)
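# A usage sketch (an assumption, not the repo's actual CLI): test_model
# only reads the attributes set below, so a bare argparse.Namespace is
# enough to drive it. The values are typical ImageNet-style settings,
# not the script's real defaults.
def _example_test_model():
    from argparse import Namespace
    args = Namespace(
        arch='dla34',          # key into dla.__dict__
        pretrained=None,
        crop_size=224,         # pool_size = crop_size // 32
        scale_size=256,
        resume='',             # checkpoint path; '' skips resuming
        start_epoch=0,
        data='data/imagenet',  # must contain a 'val' subdirectory
        data_name='imagenet',
        crop_10=False,         # True switches to 10-crop evaluation
        batch_size=32,
        workers=4,
        cuda=torch.cuda.is_available())
    test_model(args)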
def test_seg(args):
    batch_size = args.batch_size
    num_workers = args.workers
    phase = args.phase

    for k, v in args.__dict__.items():
        print(k, ':', v)

    single_model = dla_up.__dict__.get(args.arch)(
        args.classes, down_ratio=args.down)
    model = torch.nn.DataParallel(single_model).cuda()

    data_dir = args.data_dir
    info = dataset.load_dataset_info(data_dir)
    normalize = transforms.Normalize(mean=info.mean, std=info.std)
    # scales = [0.5, 0.75, 1.25, 1.5, 1.75]
    scales = [0.5, 0.75, 1.25, 1.5]
    t = []
    if args.crop_size > 0:
        t.append(transforms.PadToSize(args.crop_size))
    t.extend([transforms.ToTensor(), normalize])
    if args.ms:
        data = SegListMS(data_dir, phase, transforms.Compose(t), scales)
    else:
        data = SegList(data_dir, phase, transforms.Compose(t),
                       out_name=True, out_size=True,
                       binary=args.classes == 2)
    test_loader = torch.utils.data.DataLoader(
        data, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=False)

    cudnn.benchmark = True

    # optionally resume from a checkpoint
    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # name the output directory after the architecture, epoch, and phase
    out_dir = '{}_{:03d}_{}'.format(args.arch, start_epoch, phase)
    if len(args.test_suffix) > 0:
        out_dir += '_' + args.test_suffix
    if args.ms:
        out_dir += '_ms'

    if args.ms:
        mAP = test_ms(test_loader, model, args.classes, save_vis=True,
                      has_gt=phase != 'test' or args.with_gt,
                      output_dir=out_dir, scales=scales)
    else:
        mAP = test(test_loader, model, args.classes, save_vis=True,
                   has_gt=phase != 'test' or args.with_gt,
                   output_dir=out_dir)
    print('mAP: ', mAP)
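# test_ms is defined elsewhere in the script. As a rough sketch of the
# usual multi-scale evaluation pattern (an assumption about its behavior,
# not the repo's actual implementation): run the model on each rescaled
# input, resize the outputs back to the base resolution, and average the
# class probabilities before taking the argmax.
def multi_scale_predict_sketch(model, image, scales):
    import torch.nn.functional as F
    h, w = image.shape[2:]
    fused = None
    with torch.no_grad():
        for s in scales + [1.0]:
            scaled = F.interpolate(image, scale_factor=s, mode='bilinear',
                                   align_corners=False)
            logits = model(scaled)  # assumed shape: (N, classes, h', w')
            logits = F.interpolate(logits, size=(h, w), mode='bilinear',
                                   align_corners=False)
            prob = F.softmax(logits, dim=1)
            fused = prob if fused is None else fused + prob
    return fused.argmax(dim=1)  # per-pixel class prediction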
def train_seg(args):
    batch_size = args.batch_size
    num_workers = args.workers
    crop_size = args.crop_size

    print(' '.join(sys.argv))
    for k, v in args.__dict__.items():
        print(k, ':', v)

    pretrained_base = args.pretrained_base
    single_model = dla_up.__dict__.get(args.arch)(
        args.classes, pretrained_base, down_ratio=args.down)
    model = torch.nn.DataParallel(single_model).cuda()
    # nn.NLLLoss handles 2d targets; it replaces the deprecated NLLLoss2d
    if args.edge_weight > 0:
        # up-weight the edge class (index 1) relative to background
        weight = torch.from_numpy(
            np.array([1, args.edge_weight], dtype=np.float32))
        criterion = nn.NLLLoss(ignore_index=255, weight=weight)
    else:
        criterion = nn.NLLLoss(ignore_index=255)
    criterion.cuda()

    data_dir = args.data_dir
    info = dataset.load_dataset_info(data_dir)
    normalize = transforms.Normalize(mean=info.mean, std=info.std)
    t = []
    if args.random_rotate > 0:
        t.append(transforms.RandomRotate(args.random_rotate))
    if args.random_scale > 0:
        t.append(transforms.RandomScale(args.random_scale))
    t.append(transforms.RandomCrop(crop_size))
    if args.random_color:
        t.append(transforms.RandomJitter(0.4, 0.4, 0.4))
    t.extend([transforms.RandomHorizontalFlip(),
              transforms.ToTensor(),
              normalize])
    train_loader = torch.utils.data.DataLoader(
        SegList(data_dir, 'train', transforms.Compose(t),
                binary=(args.classes == 2)),
        batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        SegList(data_dir, 'val', transforms.Compose([
            transforms.RandomCrop(crop_size),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]), binary=(args.classes == 2)),
        batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True)

    optimizer = torch.optim.SGD(single_model.optim_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True
    best_prec1 = 0
    start_epoch = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        validate(val_loader, model, criterion, eval_score=accuracy)
        return

    for epoch in range(start_epoch, args.epochs):
        lr = adjust_learning_rate(args, optimizer, epoch)
        print('Epoch: [{0}]\tlr {1:.06f}'.format(epoch, lr))
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch,
              eval_score=accuracy)
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, eval_score=accuracy)

        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        checkpoint_path = 'checkpoint_latest.pth.tar'
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best, filename=checkpoint_path)
        if (epoch + 1) % args.save_freq == 0:
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)
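# adjust_learning_rate is defined elsewhere in the script; the loop above
# only relies on it updating the optimizer and returning the new rate. A
# minimal sketch, assuming plain step decay with a hypothetical
# args.lr_step interval (the real schedule may differ):
def adjust_learning_rate_sketch(args, optimizer, epoch):
    lr = args.lr * (0.1 ** (epoch // args.lr_step))  # decay 10x per step
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr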
def train_net(net, paras):
    # parameters
    img_dir = paras.image_dir
    anno_path = paras.anno_path
    checkpoint_dir = paras.model_save_dir
    val_percent = 0.1
    epochs = paras.epochs
    batch_size = paras.batch_size
    lr = paras.learning_rate
    num_workers = 2

    # torch model saver
    saver = ModelSaver(max_save_num=5)

    # load dataset info and split it into train/validation subsets
    dataset_info = load_dataset_info(img_dir, anno_path)
    train_set_info, valid_set_info = split_dataset_info(dataset_info,
                                                        val_percent)

    # build dataloaders
    building_trainset = Building_Dataset(train_set_info)
    building_validset = Building_Dataset(valid_set_info)
    train_dataloader = torch.utils.data.DataLoader(
        building_trainset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers)
    valid_dataloader = torch.utils.data.DataLoader(
        building_validset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers)

    # optimizer
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9,
                          weight_decay=0.0005)
    # loss function
    # criterion = nn.L1Loss(reduce=True, size_average=True)
    criterion = nn.BCELoss()

    train_num = len(building_trainset)
    valid_num = len(building_validset)
    print('''
    Starting training:
        Total epochs:         {}
        Batch size:           {}
        Learning rate:        {}
        Training size:        {}
        Validation size:      {}
        Checkpoint save dir:  {}
    '''.format(epochs, batch_size, lr, train_num, valid_num, checkpoint_dir))

    # ------------------------
    # start training...
    # ------------------------
    best_valid_loss = float('inf')
    for epoch in range(1, epochs + 1):
        print('Starting epoch {}/{}.'.format(epoch, epochs))

        # training
        net.train()
        epoch_loss = 0
        for idx, data in enumerate(train_dataloader):
            imgs, true_masks = data
            imgs = imgs.cuda()
            true_masks = true_masks.cuda()

            pred_masks = net(imgs)

            # compute loss
            loss = criterion(pred_masks, true_masks)
            epoch_loss += loss.item()
            if idx % 10 == 0:
                print(f'{idx}/{len(train_dataloader)}, loss: {loss.item()}')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = epoch_loss / len(train_dataloader)
        print('Epoch finished! Loss: {}\n'.format(epoch_loss))

        # validation
        net.eval()
        valid_loss = 0
        with torch.no_grad():
            for idx, data in enumerate(valid_dataloader):
                if idx % 10 == 0:
                    print(idx, '/', len(valid_dataloader))
                imgs, true_masks = data
                imgs = imgs.cuda()
                true_masks = true_masks.cuda()

                # inference
                pred_masks = net(imgs)
                # compute loss
                loss = criterion(pred_masks, true_masks)
                valid_loss += loss.item()
        valid_loss = valid_loss / len(valid_dataloader)
        print('Validation finished! Loss: {} Best loss before: {}\n'.format(
            valid_loss, best_valid_loss))

        # save checkpoint whenever validation improves
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            print('New best model found, saving checkpoint {}...'.format(
                epoch))
            model_save_path = os.path.join(
                checkpoint_dir, '{}_CP{}.pth'.format(best_valid_loss, epoch))
            # torch.save(net.state_dict(), model_save_path)
            saver.save_new_model(net, model_save_path)
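# ModelSaver is imported from elsewhere; train_net only relies on
# save_new_model(net, path). A minimal sketch, assuming it keeps only the
# newest max_save_num checkpoint files on disk (the real class may differ):
class ModelSaverSketch:
    def __init__(self, max_save_num=5):
        self.max_save_num = max_save_num
        self.saved_paths = []  # oldest first

    def save_new_model(self, net, model_save_path):
        torch.save(net.state_dict(), model_save_path)
        self.saved_paths.append(model_save_path)
        # evict the oldest checkpoint once over budget
        while len(self.saved_paths) > self.max_save_num:
            oldest = self.saved_paths.pop(0)
            if os.path.exists(oldest):
                os.remove(oldest)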
def run_training(args):
    model = dla.__dict__[args.arch](
        pretrained=args.pretrained, num_classes=args.classes,
        pool_size=args.crop_size // 32)
    model = torch.nn.DataParallel(model)

    best_prec1 = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    data = dataset.get_data(args.data_name)
    if data is None:
        data = dataset.load_dataset_info(args.data, data_name=args.data_name)
    if data is None:
        raise ValueError('{} is not pre-defined in dataset.py and info.json '
                         'does not exist in {}'.format(args.data_name,
                                                       args.data))

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = data_transforms.Normalize(mean=data.mean, std=data.std)
    tt = [data_transforms.RandomResizedCrop(
        args.crop_size, min_area_ratio=args.min_area_ratio,
        aspect_ratio=args.aspect_ratio)]
    if data.eigval is not None and data.eigvec is not None \
            and args.random_color:
        lighting = data_transforms.Lighting(0.1, data.eigval, data.eigvec)
        jitter = data_transforms.RandomJitter(0.4, 0.4, 0.4)
        tt.extend([jitter, lighting])
    tt.extend([data_transforms.RandomHorizontalFlip(),
               data_transforms.ToTensor(),
               normalize])

    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, data_transforms.Compose(tt)),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(args.scale_size),
            transforms.CenterCrop(args.crop_size),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    if args.evaluate:
        validate(args, val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(args, optimizer, epoch)

        # train for one epoch
        train(args, train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(args, val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        checkpoint_path = 'checkpoint_latest.pth.tar'
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best, filename=checkpoint_path)
        if (epoch + 1) % args.check_freq == 0:
            history_path = 'checkpoint_{:03d}.pth.tar'.format(epoch + 1)
            shutil.copyfile(checkpoint_path, history_path)
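# save_checkpoint is called above as save_checkpoint(state, is_best,
# filename=...). A minimal sketch matching that call site, following the
# standard PyTorch ImageNet-example pattern (an assumption about the
# script's actual helper):
def save_checkpoint_sketch(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        # keep a separate copy of the best-scoring checkpoint
        shutil.copyfile(filename, 'model_best.pth.tar')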