def __main__(args):
    """Run validation of a pretrained PSPNet on the Cityscapes 'val' split.

    ``args`` is indexed with ``'val_batch_size'`` for the loader batch size
    and is also forwarded unchanged to ``validate``.
    """
    # Initialise the pretrained network on device ``gpu0``.
    pspnet = PSPNet(n_classes=cityscapes.num_classes).cuda(gpu0)
    pspnet.load_pretrained_model(model_path=pspnet_path)

    # Transformations: per-channel means with unit std, applied after the
    # tensor has been scaled by 255.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    input_steps = [
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ]
    val_input_transform = standard_transforms.Compose(input_steps)

    mask_steps = [extended_transforms.MaskToTensor()]
    target_transform = standard_transforms.Compose(mask_steps)

    # Inverse pipeline handed to ``validate`` together with ``visualize``.
    restore_steps = [
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ]
    restore_transform = standard_transforms.Compose(restore_steps)
    visualize = standard_transforms.ToTensor()

    # Dataset and loader (fixed order, no shuffling).
    val_set = cityscapes.CityScapes(
        'val',
        transform=val_input_transform,
        target_transform=target_transform,
    )
    val_loader = DataLoader(
        val_set,
        batch_size=args['val_batch_size'],
        num_workers=8,
        shuffle=False,
    )

    validate(
        pspnet,
        val_loader,
        cityscapes.num_classes,
        args,
        restore_transform,
        visualize,
    )
def main():
    """Train FCN8s on Cityscapes ('fine' annotations), optionally resuming.

    Reads the module-level ``args`` dict together with ``ckpt_path`` and
    ``exp_name``. When ``args['snapshot']`` is non-empty, model and optimizer
    state are restored from files under ``ckpt_path/exp_name`` and the epoch
    counter / best record are parsed from the underscore-separated snapshot
    file name. ``args['best_record']`` is (re)initialised as a side effect.
    """
    net = FCN8s(num_classes=cityscapes.num_classes, caffe=True).cuda()

    if not args['snapshot']:
        curr_epoch = 1
        args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0,
                               'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        # Odd-indexed fields of the snapshot name carry the recorded values.
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {'epoch': int(split_snapshot[1]),
                               'val_loss': float(split_snapshot[3]),
                               'acc': float(split_snapshot[5]),
                               'acc_cls': float(split_snapshot[7]),
                               'mean_iu': float(split_snapshot[9]),
                               'fwavacc': float(split_snapshot[11])}

    net.train()

    # Per-channel means with unit std, applied after scaling the tensor by 255.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    short_size = int(min(args['input_size']) / 0.875)
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.RandomCrop(args['input_size']),
        joint_transforms.RandomHorizontallyFlip()
    ])
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size'])
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    # Inverse of ``input_transform``; passed to ``validate`` with ``visualize``.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.Lambda(lambda x: x.div_(255)),
        standard_transforms.ToPILImage(),
        extended_transforms.FlipChannels()
    ])
    visualize = standard_transforms.ToTensor()

    train_set = cityscapes.CityScapes('fine', 'train', joint_transform=train_joint_transform,
                                      transform=input_transform, target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=8,
                              shuffle=True)
    val_set = cityscapes.CityScapes('fine', 'val', joint_transform=val_joint_transform,
                                    transform=input_transform, target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=args['val_batch_size'], num_workers=8,
                            shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=False, ignore_index=cityscapes.ignore_label).cuda()

    # Bias parameters get twice the base learning rate and no weight decay.
    optimizer = optim.Adam([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], betas=(args['momentum'], 0.999))

    if args['snapshot']:
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        # The restored state carries the old learning rates; apply the current ones.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Record the run configuration; the context manager guarantees the handle
    # is flushed and closed (the previous bare open().write() leaked it).
    log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')
    with open(log_path, 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=args['lr_patience'], min_lr=1e-10,
                                  verbose=True)
    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch, args,
                            restore_transform, visualize)
        scheduler.step(val_loss)
def main():
    """Create the model and start the training.

    Trains an FCN8s network (wrapped in ``DataParallel``) on the source
    dataset selected by ``args.data_dir``. After every epoch the model is
    evaluated on the target validation loader with ``test_miou``, and the
    weights are written to ``<args.snapshot_dir>/final_fcn.pth`` whenever the
    returned score exceeds the best one seen so far.
    """
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))
    n_class = args.num_classes

    # Create network
    student_net = FCN8s(args.num_classes, args.model_path_prefix)
    student_net = torch.nn.DataParallel(student_net)
    student_net = student_net.cuda()

    # Per-channel means with unit std, applied after scaling the tensor by 255.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    # Validation images are rescaled here; no joint transform is applied,
    # so the labels are left untouched.
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    # A '5' anywhere in the data directory selects the GTA5 dataset class.
    if '5' in args.data_dir:
        src_dataset_cls = GTA5DataSetLMDB
    else:
        src_dataset_cls = CityscapesDataSetLMDB
    src_dataset = src_dataset_cls(
        args.data_dir, args.data_list,
        joint_transform=train_joint_transform,
        transform=input_transform,
        target_transform=target_transform,
    )
    src_loader = data.DataLoader(src_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    tgt_val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target, args.data_list_target,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    tgt_val_loader = data.DataLoader(
        tgt_val_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    optimizer = optim.SGD(student_net.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear')

    # Summed loss; it is divided by the batch size inside the loop.
    src_criterion = torch.nn.CrossEntropyLoss(ignore_index=255, size_average=False)

    num_batches = len(src_loader)
    highest = 0

    for epoch in range(args.num_epoch):
        cls_loss_rec = AverageMeter()
        miu_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, src_data in enumerate(src_loader):
            # Re-enter training mode on every step (set before each batch,
            # as in the original loop, since evaluation runs between epochs).
            student_net.train()
            optimizer.zero_grad()

            # train with source
            src_images, src_label, _ = src_data
            src_images = src_images.cuda()
            src_label = src_label.cuda()
            data_time_rec.update(time.time() - tem_time)

            src_output = student_net(src_images)

            # Segmentation loss, averaged over the images in the batch.
            cls_loss_value = src_criterion(src_output, src_label)
            cls_loss_value /= src_images.shape[0]

            total_loss = cls_loss_value
            total_loss.backward()
            optimizer.step()

            # Training mIoU on the current batch, evaluated with the model's
            # class count rather than a hard-coded 19.
            _, predict_labels = torch.max(src_output, 1)
            lbl_pred = predict_labels.detach().cpu().numpy()
            lbl_true = src_label.detach().cpu().numpy()
            _, _, _, mean_iu, _ = _evaluate(lbl_pred, lbl_true, n_class)

            cls_loss_rec.update(cls_loss_value.item())
            miu_rec.update(mean_iu)

            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Mean iu: {miu_rec.avg*100:.1f}   '
                    f'CLS: {cls_loss_rec.avg:.2f}')

        # End of epoch: evaluate on the target validation set, keep the best.
        miu = test_miou(student_net, tgt_val_loader, upsample, './dataset/info.json')
        if miu > highest:
            torch.save(student_net.module.state_dict(),
                       osp.join(args.snapshot_dir, 'final_fcn.pth'))
            highest = miu
            print('>' * 50 + f'save highest with {miu:.2%}')
def main(args):
    """Fine-tune a segmentation network on target images via self-training.

    A frozen copy of the segmentation network (``net_static``) labels the
    output of a frozen generator (``gen_model``); the trainable copy
    (``net``) is optimised with cross-entropy against those labels using the
    original target image as input. Before training, the network is
    evaluated both directly and behind the generator. During training it is
    re-evaluated every ``args.checkpoint_freq`` iterations and the best
    weights are saved to ``<args.save_path_prefix>/cityscapes_best_fcn.pth``,
    which is reloaded for a final evaluation.

    Raises:
        ValueError: if ``args.seg_net`` is neither ``'fcn'`` nor
            ``'deeplab_ibn'``.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)

    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    # Statistics of the generator's input space (maps [0, 1] to [-1, 1]).
    gen_mean_std = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    tgt_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*gen_mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    if args.seg_net == 'fcn':
        mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])
        val_input_transform = standard_transforms.Compose([
            extended_transforms.FreeScale((h, w)),
            extended_transforms.FlipChannels(),
            standard_transforms.ToTensor(),
            standard_transforms.Lambda(lambda x: x.mul_(255)),
            standard_transforms.Normalize(*mean_std),
        ])
    else:
        seg_mean_std = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        val_input_transform = standard_transforms.Compose([
            extended_transforms.FreeScale((h, w)),
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*seg_mean_std),
        ])

    tgt_dataset = Cityscapes16DataSetLMDB(
        args.data_dir_target, args.data_list_target,
        joint_transform=joint_transform,
        transform=tgt_input_transform,
        target_transform=target_transform,
    )
    tgt_loader = data.DataLoader(tgt_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True,
                                 drop_last=True)
    val_dataset = Cityscapes16DataSetLMDB(
        args.data_dir_val, args.data_list_val,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    val_loader = data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear', align_corners=True)

    if args.seg_net == 'fcn':
        net = FCN8s(args.n_classes, pretrained=False)
        net_static = FCN8s(args.n_classes, pretrained=False)
        file_name = os.path.join(args.resume, args.fcn_name)
    elif args.seg_net == 'deeplab_ibn':
        # Previously the model was bound to an unused name, leaving ``net``
        # and ``net_static`` undefined (NameError) for this architecture.
        net = resnet101_ibn_a_deeplab()
        net_static = resnet101_ibn_a_deeplab()
        file_name = os.path.join(args.resume, 'deeplab_ibn.pth')
    else:
        raise ValueError(f'unsupported seg_net: {args.seg_net!r}')

    # Read the checkpoint once; load_state_dict copies into each model.
    state_dict = torch.load(file_name)
    net.load_state_dict(state_dict)
    net_static.load_state_dict(state_dict)

    for param in net_static.parameters():
        param.requires_grad = False

    optimizer = torch.optim.SGD(net.parameters(), args.learning_rate, args.momentum)

    net = torch.nn.DataParallel(net.cuda())
    net_static = torch.nn.DataParallel(net_static.cuda())

    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    gen_model = define_G()
    gen_model.load_state_dict(torch.load(os.path.join(args.resume, args.gen_name)))
    gen_model.eval()
    for param in gen_model.parameters():
        param.requires_grad = False
    gen_model = torch.nn.DataParallel(gen_model.cuda())

    def normalize(x, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """Map a batch of [0, 1] images into the segmentation net's input space.

        For ``'fcn'`` the channel order is reversed, values are scaled by 255
        and fixed per-channel means are subtracted. Otherwise ``mean`` / ``std``
        standardisation is applied; the previous per-channel loop only rebound
        a local name and returned ``x`` unchanged.
        """
        if args.seg_net == 'fcn':
            caffe_mean = x.new_tensor([103.939, 116.779, 123.68]).view(1, 3, 1, 1)
            return x[:, [2, 1, 0], :, :].mul(255.0).sub(caffe_mean)
        mean_t = x.new_tensor(mean).view(1, -1, 1, 1)
        std_t = x.new_tensor(std).view(1, -1, 1, 1)
        return x.sub(mean_t).div(std_t)

    def de_normalize(x, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        """Invert ``Normalize(mean, std)`` on a batch: ``x * std + mean``.

        The earlier code added ``std`` instead of ``mean``; that is identical
        for the defaults but wrong for any other statistics.
        """
        mean_t = x.new_tensor(mean).view(1, -1, 1, 1)
        std_t = x.new_tensor(std).view(1, -1, 1, 1)
        return x.mul(std_t).add(mean_t)

    # ###################################################
    # direct test with gen
    # ###################################################
    print('Direct Test')
    test_miou(net, val_loader, upsample, './dataset/info.json')

    direct_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    val_dataset_direct = Cityscapes16DataSetLMDB(
        args.data_dir_val, args.data_list_val,
        transform=direct_input_transform,
        target_transform=target_transform,
    )
    val_loader_direct = data.DataLoader(val_dataset_direct,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers,
                                        pin_memory=True,
                                        drop_last=False)

    class NewModel(object):
        """Chain generator and segmentation net so they can be tested as one model."""

        def __init__(self, gen_net, val_net):
            self.gen_net = gen_net
            self.val_net = val_net

        def __call__(self, x):
            x = de_normalize(self.gen_net(x))
            new_x = normalize(x)
            out = self.val_net(new_x)
            return out

        def eval(self):
            self.gen_net.eval()
            self.val_net.eval()

    new_model = NewModel(gen_model, net)
    print('Test with Gen')
    test_miou(new_model, val_loader_direct, upsample, './dataset/info.json')

    num_batches = len(tgt_loader)
    highest = 0
    prob = torch.nn.Softmax(dim=1)

    for epoch in range(args.num_epoch):
        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, batch_data in enumerate(tgt_loader):
            iteration = batch_index + 1 + epoch * num_batches

            net.train()
            net_static.eval()  # fine-tune use eval

            img, _, _ = batch_data
            img = img.cuda()
            data_time_rec.update(time.time() - tem_time)

            # The labelling branch is frozen, so no graph is needed for it.
            with torch.no_grad():
                gen_output = gen_model(img)
                gen_seg_output_logits = net_static(normalize(de_normalize(gen_output)))
            ori_seg_output_logits = net(normalize(de_normalize(img)))

            max_value, label = torch.max(prob(gen_seg_output_logits), dim=1)

            # Build a mask of pixels to drop from the loss (set to 255).
            # NOTE(review): ``topk(...)[0][0]`` is the largest confidence of the
            # class, so ``max_value > large_value`` can never hold inside
            # ``tem_mask`` and the mask appears to stay empty; ``int(...)`` can
            # also yield k == 0. The intended threshold is unclear, so the
            # logic is left unchanged. The class count is hard-coded to 19
            # rather than args.n_classes. Confirm with the author.
            label_mask = torch.zeros(label.shape, dtype=torch.uint8).cuda()
            for tem_label in range(19):
                tem_mask = label == tem_label
                if torch.sum(tem_mask) < 5:
                    continue
                value_vec = max_value[tem_mask]
                large_value = torch.topk(
                    value_vec, int(args.percent * value_vec.shape[0]))[0][0]
                large_mask = max_value > large_value
                label_mask = label_mask | (tem_mask & large_mask)
            label[label_mask] = 255

            loss = criterion(ori_seg_output_logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_rec.update(loss.item())
            writer.add_scalar('A_seg_loss', loss.item(), iteration)
            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Loss: {loss_rec.avg:.2f}')

            if iteration % args.checkpoint_freq == 0:
                mean_iu = test_miou(net, val_loader, upsample, './dataset/info.json',
                                    print_results=False)
                if mean_iu > highest:
                    torch.save(
                        net.module.state_dict(),
                        os.path.join(args.save_path_prefix, 'cityscapes_best_fcn.pth'))
                    highest = mean_iu
                    print(f'save fcn model with {mean_iu:.2%}')

    print(('-' * 100 + '\n') * 3)
    print('>' * 50 + 'Final Model')
    net.module.load_state_dict(
        torch.load(os.path.join(args.save_path_prefix, 'cityscapes_best_fcn.pth')))
    test_miou(net, val_loader, upsample, './dataset/info.json')

    writer.close()
def main():
    """Create the model and start the training.

    Trains a DeepLabV3 network (wrapped in ``DataParallel``) on the source
    dataset selected by ``args.data_dir`` for ``args.iterations`` steps.
    Every ``args.eval_freq`` iterations the model is evaluated on the target
    validation loader with ``test_miou`` and a snapshot named
    ``<iteration>_<score*1000>.pth`` is written to ``args.snapshot_dir`` when
    the score improves on the best one so far.
    """
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))
    n_class = args.num_classes

    # Create network
    if args.bn_sync:
        print('Using Sync BN')
        deeplabv3.BatchNorm2d = partial(InPlaceABNSync, activation='none')
    net = get_deeplabV3(args.num_classes, args.model_path_prefix)
    if not args.bn_sync:
        net.freeze_bn()
    net = torch.nn.DataParallel(net)
    net = net.cuda()

    # Per-channel means with unit std, applied after scaling the tensor by 255.
    mean_std = ([104.00698793, 116.66876762, 122.67891434], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    # Validation images are rescaled here; no joint transform is applied,
    # so the labels are left untouched.
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    # A '5' anywhere in the data directory selects the GTA5 dataset class.
    if '5' in args.data_dir:
        src_dataset_cls = GTA5DataSetLMDB
    else:
        src_dataset_cls = CityscapesDataSetLMDB
    src_dataset = src_dataset_cls(
        args.data_dir, args.data_list,
        joint_transform=train_joint_transform,
        transform=input_transform,
        target_transform=target_transform,
    )
    src_loader = data.DataLoader(src_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    tgt_val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target, args.data_list_target,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    tgt_val_loader = data.DataLoader(
        tgt_val_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # freeze bn
    for module in net.module.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            for param in module.parameters():
                param.requires_grad = False

    # Only parameters that still require gradients are optimised.
    optimizer = optim.SGD(
        [{
            'params': filter(lambda p: p.requires_grad, net.module.parameters()),
            'lr': args.learning_rate
        }],
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear')

    criterion = CriterionDSN(ignore_index=255)

    max_iter = args.iterations
    i_iter = 0
    highest_miu = 0

    # Cycle over the source loader until ``max_iter`` steps have been taken.
    while True:
        cls_loss_rec = AverageMeter()
        miu_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, src_data in enumerate(src_loader):
            i_iter += 1
            lr = adjust_learning_rate(args, optimizer, i_iter, max_iter)
            net.train()
            optimizer.zero_grad()

            # train with source
            src_images, src_label, _ = src_data
            src_images = src_images.cuda()
            src_label = src_label.cuda()
            data_time_rec.update(time.time() - tem_time)

            src_output = net(src_images)

            # Segmentation loss over the full network output.
            cls_loss_value = criterion(src_output, src_label)

            total_loss = cls_loss_value
            total_loss.backward()
            optimizer.step()

            # Training mIoU on the first element of the output, resized to
            # the input resolution. No graph is needed for monitoring, and
            # ``interpolate`` replaces the deprecated ``upsample`` alias.
            with torch.no_grad():
                main_output = torch.nn.functional.interpolate(
                    input=src_output[0], size=(h, w),
                    mode='bilinear', align_corners=True)
                _, predict_labels = torch.max(main_output, 1)
            lbl_pred = predict_labels.detach().cpu().numpy()
            lbl_true = src_label.detach().cpu().numpy()
            _, _, _, mean_iu, _ = _evaluate(lbl_pred, lbl_true, n_class)

            cls_loss_rec.update(cls_loss_value.item())
            miu_rec.update(mean_iu)

            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if i_iter % args.print_freq == 0:
                print(
                    f'Iter: [{i_iter}/{max_iter}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Mean iu: {miu_rec.avg*100:.1f}   '
                    f'CLS: {cls_loss_rec.avg:.2f}')

            if i_iter % args.eval_freq == 0:
                miu = test_miou(net, tgt_val_loader, upsample, './dataset/info.json')
                if miu > highest_miu:
                    torch.save(
                        net.module.state_dict(),
                        osp.join(args.snapshot_dir, f'{i_iter:d}_{miu*1000:.0f}.pth'))
                    highest_miu = miu
                print(f'>>>>>>>>>Learning Rate {lr}<<<<<<<<<')

            # ``>=`` so a non-positive ``max_iter`` cannot loop forever.
            if i_iter >= max_iter:
                return
def main():
    """Fine-tune new output layers on a frozen, pretrained PSPNet using KITTI.

    All pretrained parameters are frozen; ``cbr_final``, ``dropout`` and
    ``classification`` are then replaced with freshly created modules sized
    for ``kitti_binary.num_classes`` outputs. Reads the module-level ``args``
    dict together with ``ckpt_path`` and ``exp_name``; when
    ``args['snapshot']`` is non-empty, model and optimizer state are restored
    and ``args['best_record']`` is parsed from the underscore-separated
    snapshot file name.
    """
    net = PSPNet(19)
    net.load_pretrained_model(
        model_path='./Caffe-PSPNet/pspnet101_cityscapes.caffemodel')

    # Freeze everything that was just loaded ...
    for param in net.parameters():
        param.requires_grad = False

    # ... then attach new layers, created after the freeze.
    net.cbr_final = conv2DBatchNormRelu(4096, 128, 3, 1, 1, False)
    net.dropout = nn.Dropout2d(p=0.1, inplace=True)
    net.classification = nn.Conv2d(128, kitti_binary.num_classes, 1, 1, 0)

    # Find total parameters and trainable parameters
    total_params = sum(p.numel() for p in net.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(f'{total_trainable_params:,} training parameters.')

    if not args['snapshot']:
        args['best_record'] = {
            'epoch': 0,
            'iter': 0,
            'val_loss': 1e10,
            'accu': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        # Odd-indexed fields of the snapshot name carry the recorded values.
        split_snapshot = args['snapshot'].split('_')
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'iter': int(split_snapshot[3]),
            'val_loss': float(split_snapshot[5]),
            'accu': float(split_snapshot[7])
        }

    net.cuda(args['gpu']).train()

    # Per-channel means with unit std, applied after scaling the tensor by 255.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    train_input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()

    train_set = kitti_binary.KITTI(mode='train',
                                   joint_transform=train_joint_transform,
                                   transform=train_input_transform,
                                   target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=8,
                              shuffle=True)
    val_set = kitti_binary.KITTI(mode='val',
                                 transform=val_input_transform,
                                 target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=8, shuffle=False)

    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.full([1], 1.05)).cuda(args['gpu'])

    # Bias parameters get twice the base learning rate and no weight decay.
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias'
        ],
        'lr': 2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias'
        ],
        'lr': args['lr'],
        'weight_decay': args['weight_decay']
    }], momentum=args['momentum'], nesterov=True)

    if args['snapshot']:
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        # The restored state carries the old learning rates; apply the current ones.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Record the run configuration; the context manager guarantees the handle
    # is flushed and closed (the previous bare open().write() leaked it).
    log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')
    with open(log_path, 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    train(train_loader, net, criterion, optimizer, args, val_loader)