def main():
    """Train a 3-class, single-channel U-Net on the bladder training split.

    Configuration is taken from enclosing-scope names used below
    (``device``, ``crop_size``, ``data_path``, ``batch_size``, ``loss_name``,
    ``n_epoch``). The assembled loader, network, loss and Adam optimizer are
    handed to ``train``.

    Raises:
        ValueError: if ``loss_name`` is neither ``'dice_'`` nor ``'bcew_'``.
    """
    net = U_Net(img_ch=1, num_classes=3).to(device)

    # Paired image/mask augmentation: rescale, small random rotation, flip.
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(384),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    center_crop = joint_transforms.CenterCrop(crop_size)
    train_input_transform = extended_transforms.ImgToTensor()
    target_transform = extended_transforms.MaskToTensor()

    make_dataset_fn = bladder.make_dataset_v2
    train_set = bladder.Bladder(data_path, 'train',
                                joint_transform=train_joint_transform,
                                center_crop=center_crop,
                                transform=train_input_transform,
                                target_transform=target_transform,
                                make_dataset_fn=make_dataset_fn)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    if loss_name == 'dice_':
        criterion = SoftDiceLossV2(activation='sigmoid', num_classes=3).to(device)
    elif loss_name == 'bcew_':
        criterion = nn.BCEWithLogitsLoss().to(device)
    else:
        # Without this branch `criterion` would be unbound and the failure
        # would surface later as an unhelpful NameError.
        raise ValueError(
            f"unsupported loss_name {loss_name!r}; expected 'dice_' or 'bcew_'")

    optimizer = optim.Adam(net.parameters(), lr=1e-4)

    train(train_loader, net, criterion, optimizer, n_epoch, 0)
def main(args):
    """Build GTA5 (source) / Cityscapes (target, val) loaders and train StyleTrans.

    Args:
        args: parsed command-line namespace. Reads ``tensorboard_log_dir``,
            ``input_size`` (a ``"W,H"`` string), ``data_dir``/``data_list``,
            ``data_dir_target``/``data_list_target``,
            ``data_dir_val``/``data_list_val``, ``batch_size`` and
            ``num_workers``; the whole namespace is also passed to
            ``StyleTrans``.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    try:
        w, h = map(int, args.input_size.split(','))

        # Shared paired augmentation for the two training domains.
        joint_transform = joint_transforms.Compose([
            joint_transforms.FreeScale((h, w)),
            joint_transforms.RandomHorizontallyFlip(),
        ])

        normalize = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        # NOTE(review): source images are only converted to tensors, while
        # target/validation images are also normalised to mean/std 0.5 —
        # confirm this asymmetry is what StyleTrans expects.
        src_input_transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
        ])
        tgt_input_transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*normalize),
        ])
        val_input_transform = standard_transforms.Compose([
            extended_transforms.FreeScale((h, w)),
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*normalize),
        ])
        target_transform = extended_transforms.MaskToTensor()

        src_dataset = GTA5DataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=src_input_transform,
            target_transform=target_transform,
        )
        src_loader = data.DataLoader(
            src_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.num_workers, pin_memory=True, drop_last=True,
        )

        tgt_dataset = CityscapesDataSetLMDB(
            args.data_dir_target, args.data_list_target,
            joint_transform=joint_transform,
            transform=tgt_input_transform,
            target_transform=target_transform,
        )
        tgt_loader = data.DataLoader(
            tgt_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.num_workers, pin_memory=True, drop_last=True,
        )

        val_dataset = CityscapesDataSetLMDB(
            args.data_dir_val, args.data_list_val,
            transform=val_input_transform,
            target_transform=target_transform,
        )
        val_loader = data.DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.num_workers, pin_memory=True, drop_last=False,
        )

        style_trans = StyleTrans(args)
        style_trans.train(src_loader, tgt_loader, val_loader, writer)
    finally:
        # Flush and close the TensorBoard event file even if training raises.
        writer.close()
def transform_image_and_mask_tt(self, image, mask, angle=None, crop_size=None):
    """Apply random joint (image, mask) augmentation using the JT transforms.

    A random horizontal flip is always a candidate; a random crop is added
    when ``crop_size`` is given and a random rotation when ``angle`` is given.
    The candidates are wrapped in ``JT.RandomOrderApply`` and sandwiched
    between ``JT.ToPILImage()`` and ``JT.ToNumpy()``.

    Only valid when ``self.use_iaa`` compares equal to ``False``.

    Returns:
        Whatever the composed JT pipeline returns for ``(image, mask)``.
    """
    # Equality (rather than `not`) is kept on purpose: e.g. None must fail.
    assert self.use_iaa == False  # noqa: E712

    candidates = [JT.RandomHorizontallyFlip()]
    if crop_size is not None:
        candidates += [JT.RandomCrop(size=crop_size)]
    if angle is not None:
        candidates += [JT.RandomRotate(degree=angle)]

    shuffled = JT.RandomOrderApply(candidates)
    pipeline = JT.Compose([JT.ToPILImage(), shuffled, JT.ToNumpy()])
    return pipeline(image, mask)
def _cycle_loader_batches(loader):
    """Yield batches from ``loader`` indefinitely, starting a new pass when one ends.

    Each pass calls ``iter(loader)`` afresh, so a shuffling loader reshuffles
    per pass. If a pass yields nothing the generator stops (so ``next()``
    raises ``StopIteration``) instead of spinning forever on an empty loader.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            return


def _select_kept_samples(keep_flags, items):
    """Return the entries of ``items`` whose paired flag in ``keep_flags`` is > 0."""
    return [item for flag, item in zip(keep_flags, items) if flag.item() > 0]


def train_with_correspondences(save_folder, startnet, args):
    """Fine-tune a Cityscapes PSPNet with segmentation and correspondence losses.

    Every iteration takes one optimizer step per segmentation loader
    (Cityscapes, optionally Vistas, and the extra set of the correspondence
    dataset) and then, once ``curr_iter`` exceeds
    ``args['n_iterations_before_corr_loss']`` and ``args['corr_loss_weight']``
    is positive, one step on the correspondence loss. Validators run every
    ``args['val_interval']`` iterations.

    Args:
        save_folder: directory for TensorBoard events, ``log.log``, the
            args dump, snapshots and ``bestval.txt``; created if missing.
        startnet: path of the initial state dict, loaded only when not
            resuming from a snapshot.
        args: dict of hyper-parameters. ``args['snapshot']`` is ``''``
            (fresh start), ``'latest'`` or a snapshot filename whose second
            ``'_'``-separated field is the iteration to resume from.
            ``args['best_record']`` is (re)initialised here.

    Raises:
        ValueError: if ``args['corr_set']`` is neither ``'rc'`` nor ``'cmu'``.
    """
    # Validate up front so a typo fails before any files or writers are created.
    if args['corr_set'] not in ('rc', 'cmu'):
        raise ValueError(
            f"unknown corr_set {args['corr_set']!r}; expected 'rc' or 'cmu'")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network().to(device)

    if args['snapshot'] == 'latest':
        args['snapshot'] = get_latest_network_name(save_folder)
    if len(args['snapshot']) == 0:  # start from the beginning
        state_dict = torch.load(startnet)
        # needed since we slightly changed the structure of the network in pspnet
        state_dict = rename_keys_to_match(state_dict)
        net.load_state_dict(state_dict)  # load original weights
        start_iter = 0
        args['best_record'] = {
            'iter': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0,
        }
    else:  # continue training
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(save_folder, args['snapshot'])))  # load weights
        split_snapshot = args['snapshot'].split('_')
        start_iter = int(split_snapshot[1])
        with open(os.path.join(save_folder, 'bestval.txt')) as f:
            best_val_dict_str = f.read()
        # NOTE(review): eval() executes whatever bestval.txt contains; it is
        # only safe while that file is written exclusively by this training
        # code. Consider ast.literal_eval if its content is plain literals.
        args['best_record'] = eval(best_val_dict_str.rstrip())

    net.train()
    freeze_bn(net)

    # Data loading setup
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    else:  # 'cmu' (validated above)
        corr_set_config = data_configs.CmuConfig()

    sliding_crop_im = joint_transforms.SlidingCropImageOnly(713, args['stride_rate'])

    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform
    target_transform = extended_transforms.MaskToTensor()

    train_joint_transform_seg = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomCrop(713),
    ])
    train_joint_transform_corr = corr_transforms.Compose([
        corr_transforms.CorrResize(1024),
        corr_transforms.CorrRandomCrop(713),
    ])

    # keep list of segmentation loaders and validators
    seg_loaders = list()
    validators = list()

    # Correspondences
    corr_set = correspondences.Correspondences(
        corr_set_config.correspondence_path,
        corr_set_config.correspondence_im_path,
        input_size=(713, 713),
        mean_std=model_config.mean_std,
        input_transform=input_transform,
        joint_transform=train_joint_transform_corr)
    corr_loader = DataLoader(corr_set, batch_size=args['train_batch_size'],
                             num_workers=args['n_workers'], shuffle=True)

    # Cityscapes Training
    c_config = data_configs.CityscapesConfig()
    seg_set_cs = cityscapes.CityScapes(
        c_config.train_im_folder,
        c_config.train_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    seg_loader_cs = DataLoader(seg_set_cs, batch_size=args['train_batch_size'],
                               num_workers=args['n_workers'], shuffle=True)
    seg_loaders.append(seg_loader_cs)

    # Cityscapes Validation
    val_set_cs = cityscapes.CityScapes(
        c_config.val_im_folder,
        c_config.val_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    val_loader_cs = DataLoader(val_set_cs, batch_size=1,
                               num_workers=args['n_workers'], shuffle=False)
    validator_cs = Validator(val_loader_cs, n_classes=c_config.n_classes,
                             save_snapshot=False, extra_name_str='Cityscapes')
    validators.append(validator_cs)

    # Vistas Training and Validation
    if args['include_vistas']:
        v_config = data_configs.VistasConfig(
            use_subsampled_validation_set=True, use_cityscapes_classes=True)

        seg_set_vis = cityscapes.CityScapes(
            v_config.train_im_folder,
            v_config.train_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            joint_transform=train_joint_transform_seg,
            sliding_crop=None,
            transform=input_transform,
            target_transform=target_transform)
        seg_loader_vis = DataLoader(seg_set_vis, batch_size=args['train_batch_size'],
                                    num_workers=args['n_workers'], shuffle=True)
        seg_loaders.append(seg_loader_vis)

        val_set_vis = cityscapes.CityScapes(
            v_config.val_im_folder,
            v_config.val_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            sliding_crop=sliding_crop_im,
            transform=input_transform,
            target_transform=target_transform,
            transform_before_sliding=pre_validation_transform)
        val_loader_vis = DataLoader(val_set_vis, batch_size=1,
                                    num_workers=args['n_workers'], shuffle=False)
        validator_vis = Validator(val_loader_vis, n_classes=v_config.n_classes,
                                  save_snapshot=False, extra_name_str='Vistas')
        validators.append(validator_vis)

    # Extra Training
    extra_seg_set = cityscapes.CityScapes(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    extra_seg_loader = DataLoader(extra_seg_set, batch_size=args['train_batch_size'],
                                  num_workers=args['n_workers'], shuffle=True)
    seg_loaders.append(extra_seg_loader)

    # Extra Validation
    extra_val_set = cityscapes.CityScapes(
        corr_set_config.val_im_folder,
        corr_set_config.val_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    extra_val_loader = DataLoader(extra_val_set, batch_size=1,
                                  num_workers=args['n_workers'], shuffle=False)
    extra_validator = Validator(extra_val_loader, n_classes=corr_set_config.n_classes,
                                save_snapshot=True, extra_name_str='Extra')
    validators.append(extra_validator)

    # Loss setup
    if args['corr_loss_type'] == 'class':
        corr_loss_fct = CorrClassLoss(input_size=[713, 713])
    else:
        corr_loss_fct = FeatureLoss(
            input_size=[713, 713],
            loss_type=args['corr_loss_type'],
            feat_dist_threshold_match=args['feat_dist_threshold_match'],
            feat_dist_threshold_nomatch=args['feat_dist_threshold_nomatch'],
            n_not_matching=0)

    # NOTE(review): 'elementwise_mean' is the PyTorch 0.4.1 spelling; newer
    # releases accept it with a deprecation warning in favour of 'mean'.
    seg_loss_fct = torch.nn.CrossEntropyLoss(
        reduction='elementwise_mean',
        ignore_index=cityscapes.ignore_label).to(device)

    # Optimizer setup: biases get twice the base learning rate and no weight decay.
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr': 2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr': args['lr'],
        'weight_decay': args['weight_decay']
    }], momentum=args['momentum'], nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(save_folder, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    with open(os.path.join(save_folder, str(datetime.datetime.now()) + '.txt'), 'w') as args_file:
        args_file.write(str(args) + '\n\n')

    log_path = os.path.join(save_folder, 'log.log')
    if len(args['snapshot']) == 0:
        f_handle = open(log_path, 'w', buffering=1)
    else:
        clean_log_before_continuing(log_path, start_iter)
        f_handle = open(log_path, 'a', buffering=1)

    ##########################################################################
    #
    #       MAIN TRAINING CONSISTS OF ALL SEGMENTATION LOSSES AND A CORRESPONDENCE LOSS
    #
    ##########################################################################
    softm = torch.nn.Softmax2d()

    val_iter = 0
    train_corr_loss = AverageMeter()
    train_seg_cs_loss = AverageMeter()
    train_seg_extra_loss = AverageMeter()
    train_seg_vis_loss = AverageMeter()

    # One meter per entry of seg_loaders, in the same order.
    seg_loss_meters = [train_seg_cs_loss]
    if args['include_vistas']:
        seg_loss_meters.append(train_seg_vis_loss)
    seg_loss_meters.append(train_seg_extra_loss)

    # Persistent batch streams: calling next(iter(loader)) every step would
    # rebuild the loader iterator (and restart its workers) for each batch.
    seg_batch_streams = [_cycle_loader_batches(loader) for loader in seg_loaders]
    corr_batch_stream = _cycle_loader_batches(corr_loader)

    corr_loss_type = args['corr_loss_type']
    use_features = corr_loss_type == 'hingeF'
    # Class-probability losses (everything except the hinge variants) work on softmax outputs.
    apply_softmax = corr_loss_type not in ('hingeF', 'hingeC')

    try:
        curr_iter = start_iter
        # Only the remaining iterations are run. Iterating a full max_iter times
        # after resuming would push curr_iter past max_iter, making the poly
        # decay base (1 - curr_iter / max_iter) negative (complex learning rate).
        for i in range(start_iter, args['max_iter']):
            decay = (1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
            optimizer.param_groups[0]['lr'] = 2 * args['lr'] * decay
            optimizer.param_groups[1]['lr'] = args['lr'] * decay

            ###################################################################
            #                  SEGMENTATION UPDATE STEP
            ###################################################################
            for si, seg_batches in enumerate(seg_batch_streams):
                # get segmentation training sample
                inputs, gts = next(seg_batches)

                slice_batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3)

                inputs = inputs.to(device)
                gts = gts.to(device)

                optimizer.zero_grad()
                outputs, aux = net(inputs)

                main_loss = args['seg_loss_weight'] * seg_loss_fct(outputs, gts)
                aux_loss = args['seg_loss_weight'] * seg_loss_fct(aux, gts)
                loss = main_loss + 0.4 * aux_loss

                loss.backward()
                optimizer.step()

                seg_loss_meters[si].update(main_loss.item(), slice_batch_pixel_size)

            ###################################################################
            #                  CORRESPONDENCE UPDATE STEP
            ###################################################################
            if args['corr_loss_weight'] > 0 and args['n_iterations_before_corr_loss'] < curr_iter:
                img_ref, img_other, pts_ref, pts_other, weights = next(corr_batch_stream)

                # Transfer data to device.
                # img_ref is from the "good" sequence with generally better
                # segmentation results.
                img_ref = img_ref.to(device)
                img_other = img_other.to(device)
                pts_ref = [p.to(device) for p in pts_ref]
                pts_other = [p.to(device) for p in pts_other]
                weights = [w.to(device) for w in weights]

                # Forward pass: the reference image is run without gradients.
                if use_features:  # works on features
                    net.output_all = True
                    with torch.no_grad():
                        output_feat_ref, aux_feat_ref, output_ref, aux_ref = net(img_ref)
                    # output1 must be last to backpropagate derivative correctly
                    output_feat_other, aux_feat_other, output_other, aux_other = net(img_other)
                    net.output_all = False
                else:  # works on class probs
                    with torch.no_grad():
                        output_ref, aux_ref = net(img_ref)
                        if apply_softmax:
                            output_ref = softm(output_ref)
                            aux_ref = softm(aux_ref)

                    # output1 must be last to backpropagate derivative correctly
                    output_other, aux_other = net(img_other)
                    if apply_softmax:
                        output_other = softm(output_other)
                        aux_other = softm(aux_other)

                # Correspondence filtering (main output)
                (pts_ref_orig, pts_other_orig, weights_orig,
                 batch_inds_to_keep_orig) = correspondences.refine_correspondence_sample(
                     output_ref, output_other, pts_ref, pts_other, weights,
                     remove_same_class=args['remove_same_class'],
                     remove_classes=args['classes_to_ignore'])
                pts_ref_orig = _select_kept_samples(batch_inds_to_keep_orig, pts_ref_orig)
                pts_other_orig = _select_kept_samples(batch_inds_to_keep_orig, pts_other_orig)
                weights_orig = _select_kept_samples(batch_inds_to_keep_orig, weights_orig)
                # remove entire samples if needed
                if use_features:
                    output_vals_ref = output_feat_ref[batch_inds_to_keep_orig]
                    output_vals_other = output_feat_other[batch_inds_to_keep_orig]
                else:
                    output_vals_ref = output_ref[batch_inds_to_keep_orig]
                    output_vals_other = output_other[batch_inds_to_keep_orig]

                # Correspondence filtering (auxiliary output)
                (pts_ref_aux, pts_other_aux, weights_aux,
                 batch_inds_to_keep_aux) = correspondences.refine_correspondence_sample(
                     aux_ref, aux_other, pts_ref, pts_other, weights,
                     remove_same_class=args['remove_same_class'],
                     remove_classes=args['classes_to_ignore'])
                pts_ref_aux = _select_kept_samples(batch_inds_to_keep_aux, pts_ref_aux)
                pts_other_aux = _select_kept_samples(batch_inds_to_keep_aux, pts_other_aux)
                weights_aux = _select_kept_samples(batch_inds_to_keep_aux, weights_aux)
                # Remove entire samples if needed. The aux mask is used in both
                # branches so the kept samples line up with pts_*_aux/weights_aux.
                if use_features:
                    aux_vals_ref = aux_feat_ref[batch_inds_to_keep_aux]
                    aux_vals_other = aux_feat_other[batch_inds_to_keep_aux]
                else:
                    aux_vals_ref = aux_ref[batch_inds_to_keep_aux]
                    aux_vals_other = aux_other[batch_inds_to_keep_aux]

                optimizer.zero_grad()

                # correspondence loss; a zero loss still tied to the graph when nothing is kept
                if output_vals_ref.size(0) > 0:
                    loss_corr_hr = corr_loss_fct(output_vals_ref, output_vals_other,
                                                 pts_ref_orig, pts_other_orig, weights_orig)
                else:
                    loss_corr_hr = 0 * output_vals_other.sum()

                if aux_vals_ref.size(0) > 0:
                    # use output from img1 as "reference"
                    loss_corr_aux = corr_loss_fct(aux_vals_ref, aux_vals_other,
                                                  pts_ref_aux, pts_other_aux, weights_aux)
                else:
                    loss_corr_aux = 0 * aux_vals_other.sum()

                loss_corr = args['corr_loss_weight'] * (loss_corr_hr + 0.4 * loss_corr_aux)
                loss_corr.backward()

                optimizer.step()
                train_corr_loss.update(loss_corr.item())

            ###################################################################
            #                  LOGGING ETC
            ###################################################################
            curr_iter += 1
            val_iter += 1

            writer.add_scalar('train_seg_loss_cs', train_seg_cs_loss.avg, curr_iter)
            writer.add_scalar('train_seg_loss_extra', train_seg_extra_loss.avg, curr_iter)
            writer.add_scalar('train_seg_loss_vis', train_seg_vis_loss.avg, curr_iter)
            writer.add_scalar('train_corr_loss', train_corr_loss.avg, curr_iter)
            writer.add_scalar('lr', optimizer.param_groups[1]['lr'], curr_iter)

            if (i + 1) % args['print_freq'] == 0:
                # The total is the iteration budget, not the corr loader length.
                str2write = '[iter %d / %d], [train corr loss %.5f] , [seg cs loss %.5f], [seg vis loss %.5f], [seg extra loss %.5f]. [lr %.10f]' % (
                    curr_iter, args['max_iter'], train_corr_loss.avg,
                    train_seg_cs_loss.avg, train_seg_vis_loss.avg,
                    train_seg_extra_loss.avg, optimizer.param_groups[1]['lr'])
                print(str2write)
                f_handle.write(str2write + "\n")

            if val_iter >= args['val_interval']:
                val_iter = 0
                for validator in validators:
                    validator.run(net, optimizer, args, curr_iter, save_folder,
                                  f_handle, writer=writer)
    finally:
        # Post training (also on error): release the log file and event writer.
        f_handle.close()
        writer.close()
def main(args):
    """Train PSPNet (ResNet-101 backbone) and validate on Cityscapes after each epoch.

    After every epoch ``test_miou`` is run on the validation loader and the
    weights are saved as ``<epoch>_<round(mIoU*100)>.pth`` under
    ``args.save_path_prefix``.

    Args:
        args: parsed command-line namespace. Reads ``input_size`` and
            ``input_size_target`` (``"W,H"`` strings), ``data_dir``/
            ``data_list`` (training), ``data_dir_target``/``data_list_target``
            (validation), the optimiser/scheduler settings, ``n_classes``,
            ``num_epoch``, ``print_freq`` and output locations.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    try:
        w, h = map(int, args.input_size.split(','))
        w_target, h_target = map(int, args.input_size_target.split(','))

        joint_transform = joint_transforms.Compose([
            joint_transforms.FreeScale((h, w)),
            joint_transforms.RandomHorizontallyFlip(),
        ])

        normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        input_transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*normalize),
        ])
        target_transform = extended_transforms.MaskToTensor()

        # NOTE(review): the training set type is chosen by looking for a '5'
        # anywhere in the path (meant to detect GTA5); this is fragile.
        if '5' in args.data_dir:
            dataset = GTA5DataSetLMDB(
                args.data_dir, args.data_list,
                joint_transform=joint_transform,
                transform=input_transform,
                target_transform=target_transform,
            )
        else:
            dataset = CityscapesDataSetLMDB(
                args.data_dir, args.data_list,
                joint_transform=joint_transform,
                transform=input_transform,
                target_transform=target_transform,
            )

        loader = data.DataLoader(
            dataset, batch_size=args.batch_size,
            shuffle=True, num_workers=args.num_workers, pin_memory=True,
        )

        # Validation images get no joint (resize/flip) transform.
        val_dataset = CityscapesDataSetLMDB(
            args.data_dir_target, args.data_list_target,
            transform=input_transform,
            target_transform=target_transform,
        )
        val_loader = data.DataLoader(
            val_dataset, batch_size=args.batch_size,
            shuffle=False, num_workers=args.num_workers, pin_memory=True,
        )

        # Handed to test_miou; presumably resizes predictions to the target resolution.
        upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear', align_corners=True)

        net = PSP(
            nclass=args.n_classes, backbone='resnet101',
            root=args.model_path_prefix, norm_layer=BatchNorm2d,
        )

        # The head and auxiliary layer use 10x the backbone learning rate.
        params_list = [
            {'params': net.pretrained.parameters(), 'lr': args.learning_rate},
            {'params': net.head.parameters(), 'lr': args.learning_rate * 10},
            {'params': net.auxlayer.parameters(), 'lr': args.learning_rate * 10},
        ]
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        criterion = SegmentationLosses(nclass=args.n_classes, aux=True, ignore_index=255)

        net = DataParallelModel(net).cuda()
        criterion = DataParallelCriterion(criterion).cuda()

        logger = utils.create_logger(args.tensorboard_log_dir, 'PSP_train')
        scheduler = utils.LR_Scheduler(args.lr_scheduler, args.learning_rate, args.num_epoch,
                                       len(loader), logger=logger, lr_step=args.lr_step)
        net_eval = Eval(net)

        num_batches = len(loader)
        best_pred = 0.0
        for epoch in range(args.num_epoch):
            loss_rec = AverageMeter()
            data_time_rec = AverageMeter()
            batch_time_rec = AverageMeter()

            tem_time = time.time()
            for batch_index, batch_data in enumerate(loader):
                scheduler(optimizer, batch_index, epoch, best_pred)
                iteration = batch_index + 1 + epoch * num_batches

                net.train()
                img, label, _name = batch_data
                img = img.cuda()
                label_cuda = label.cuda()
                data_time_rec.update(time.time() - tem_time)

                output = net(img)
                loss = criterion(output, label_cuda)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_rec.update(loss.item())
                writer.add_scalar('A_seg_loss', loss.item(), iteration)
                batch_time_rec.update(time.time() - tem_time)
                tem_time = time.time()

                if (batch_index + 1) % args.print_freq == 0:
                    print(
                        f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                        f'Time: {batch_time_rec.avg:.2f}   '
                        f'Data: {data_time_rec.avg:.2f}   '
                        f'Loss: {loss_rec.avg:.2f}'
                    )

            mean_iu = test_miou(net_eval, val_loader, upsample, './style_seg/dataset/info.json')
            # Keep the best validation mIoU so the scheduler receives the real
            # value instead of the initial 0.0 for the whole run.
            best_pred = max(best_pred, mean_iu)
            torch.save(
                net.module.state_dict(),
                os.path.join(args.save_path_prefix, f'{epoch:d}_{mean_iu*100:.0f}.pth')
            )
    finally:
        writer.close()
def main(train_args):
    """Train an 11-class FCN-8s on the PRIMA training split.

    Args:
        train_args: dict of hyper-parameters. ``train_args['snapshot']`` is
            ``''`` for a fresh start, or a snapshot filename whose
            ``'_'``-separated fields 1, 3, 5, 7, 9 and 11 hold epoch,
            val_loss, acc, acc_cls, mean_iu and fwavacc; the matching
            ``'opt_' + snapshot`` optimizer state is loaded from
            ``ckpt_path/exp_name``. ``train_args['best_record']`` is set here.
    """
    # The criterion below lives on the GPU, so the network must as well
    # (matching the other training scripts, which call .cuda() on the net).
    net = FCN8s(num_classes=11).cuda()

    if len(train_args['snapshot']) == 0:
        curr_epoch = 1
        train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0,
                                     'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {'epoch': int(split_snapshot[1]),
                                     'val_loss': float(split_snapshot[3]),
                                     'acc': float(split_snapshot[5]),
                                     'acc_cls': float(split_snapshot[7]),
                                     'mean_iu': float(split_snapshot[9]),
                                     'fwavacc': float(split_snapshot[11])}

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(2000),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    train_set = PRIMA('train', joint_transform=train_joint_transform,
                      transform=input_transform, target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=1, num_workers=0, shuffle=True)

    criterion = CrossEntropyLoss2d(size_average=False).cuda()

    # Biases get twice the base learning rate and no weight decay.
    optimizer = optim.Adam([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], betas=(train_args['momentum'], 0.999))

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    with open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w') as args_file:
        args_file.write(str(train_args) + '\n\n')

    # NOTE(review): there is no validation pass here, so no LR-on-plateau
    # scheduling is performed (the previous, never-stepped scheduler was removed).
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args)
def main():
    """Train a Caffe-style FCN-8s on fine Cityscapes, validating every epoch.

    Configuration comes from the enclosing-scope ``args`` dict together with
    ``ckpt_path`` and ``exp_name``. ``args['snapshot']`` is ``''`` for a fresh
    start, or a snapshot filename whose ``'_'``-separated fields 1, 3, 5, 7, 9
    and 11 hold epoch, val_loss, acc, acc_cls, mean_iu and fwavacc. The
    learning rate is reduced on plateaus of the validation loss.
    """
    net = FCN8s(num_classes=cityscapes.num_classes, caffe=True).cuda()

    if len(args['snapshot']) == 0:
        curr_epoch = 1
        args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0,
                               'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {'epoch': int(split_snapshot[1]),
                               'val_loss': float(split_snapshot[3]),
                               'acc': float(split_snapshot[5]),
                               'acc_cls': float(split_snapshot[7]),
                               'mean_iu': float(split_snapshot[9]),
                               'fwavacc': float(split_snapshot[11])}

    net.train()

    # Per-channel means on a 0-255 scale with unit std; the input transform
    # below flips channels and multiplies by 255 before normalising.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    short_size = int(min(args['input_size']) / 0.875)
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.RandomCrop(args['input_size']),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size']),
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()
    # Inverse of input_transform, used for visualisation during validation.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.Lambda(lambda x: x.div_(255)),
        standard_transforms.ToPILImage(),
        extended_transforms.FlipChannels(),
    ])
    visualize = standard_transforms.ToTensor()

    train_set = cityscapes.CityScapes('fine', 'train', joint_transform=train_joint_transform,
                                      transform=input_transform, target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=8, shuffle=True)
    val_set = cityscapes.CityScapes('fine', 'val', joint_transform=val_joint_transform,
                                    transform=input_transform, target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=args['val_batch_size'], num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=False, ignore_index=cityscapes.ignore_label).cuda()

    # Biases get twice the base learning rate and no weight decay.
    optimizer = optim.Adam([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], betas=(args['momentum'], 0.999))

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Close the args dump deterministically instead of relying on garbage collection.
    with open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w') as args_file:
        args_file.write(str(args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=args['lr_patience'], min_lr=1e-10, verbose=True)
    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch, args, restore_transform, visualize)
        scheduler.step(val_loss)
def train_without_ignite(model, loss, batch_size, img_size, epochs, lr,
                         num_workers, optimizer, logger, gray_image=False,
                         scheduler=None, viz=True):
    """Train ``model`` on Figaro, evaluating on the Figaro test split every epoch.

    Args:
        model: network to train.
        loss: callable ``loss(pred, mask)``, or ``loss(pred, mask, gray)`` when
            ``gray_image`` is set.
        batch_size: training batch size (the test loader uses batch size 1).
        img_size: random-crop size for training images.
        epochs: number of train/test rounds.
        lr: accepted for interface compatibility; not used here (the
            optimizer is passed in ready-made).
        num_workers: DataLoader worker count.
        optimizer: optimizer stepped after each training batch.
        logger: receives per-epoch average losses.
        gray_image: whether batches carry a third (grayscale) element.
        scheduler: optional object whose ``step`` receives the test loss.
        viz: when true, the last processed prediction/mask pair of each phase
            is sent to a local visdom server.
    """
    import visdom
    from utils.metrics import Accuracy, IoU

    DEFAULT_PORT = 8097
    DEFAULT_HOSTNAME = "http://localhost"
    if viz:
        vis = visdom.Visdom(port=DEFAULT_PORT, server=DEFAULT_HOSTNAME)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    data_loader = {}

    joint_transforms = jnt_trnsf.Compose([
        jnt_trnsf.RandomCrop(img_size),
        jnt_trnsf.RandomRotate(5),
        jnt_trnsf.RandomHorizontallyFlip(),
    ])
    train_image_transforms = std_trnsf.Compose([
        std_trnsf.ColorJitter(0.05, 0.05, 0.05, 0.05),
        std_trnsf.ToTensor(),
        std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    test_joint_transforms = jnt_trnsf.Compose([jnt_trnsf.Safe32Padding()])
    test_image_transforms = std_trnsf.Compose([
        std_trnsf.ToTensor(),
        std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    mask_transforms = std_trnsf.Compose([std_trnsf.ToTensor()])

    data_loader['train'] = get_loader(dataset='figaro', train=True,
                                      joint_transforms=joint_transforms,
                                      image_transforms=train_image_transforms,
                                      mask_transforms=mask_transforms,
                                      batch_size=batch_size, shuffle=True,
                                      num_workers=num_workers, gray_image=gray_image)
    data_loader['test'] = get_loader(dataset='figaro', train=False,
                                     joint_transforms=test_joint_transforms,
                                     image_transforms=test_image_transforms,
                                     mask_transforms=mask_transforms,
                                     batch_size=1, shuffle=True,
                                     num_workers=num_workers, gray_image=gray_image)

    for epoch in range(epochs):
        for phase in ['train', 'test']:
            is_train = phase == 'train'
            model.train(is_train)
            loader = data_loader[phase]
            last_index = len(loader) - 1

            running_loss = 0.0
            n_batches = 0
            pred_mask = mask = None
            # Gradients only while training; the context manager restores the
            # previous grad mode even if an exception escapes the loop.
            with torch.set_grad_enabled(is_train):
                for i, data in enumerate(tqdm(loader, file=sys.stdout)):
                    # The final batch of each loader is skipped (original
                    # behaviour). NOTE(review): presumably to avoid a short
                    # last batch — confirm.
                    if i == last_index:
                        break

                    data_ = [
                        t.to(device) if isinstance(t, torch.Tensor) else t
                        for t in data
                    ]
                    if gray_image:
                        img, mask, gray = data_
                    else:
                        img, mask = data_

                    model.zero_grad()
                    pred_mask = model(img)

                    if gray_image:
                        l = loss(pred_mask, mask, gray)
                    else:
                        l = loss(pred_mask, mask)

                    if is_train:
                        l.backward()
                        optimizer.step()

                    running_loss += l.item()
                    n_batches += 1

            # Average over the batches actually processed (the skipped last
            # batch must not count); guard against a loader with <= 1 batch.
            epoch_loss = running_loss / max(n_batches, 1)
            # Only visualise predictions produced in this phase.
            show = viz and n_batches > 0

            if is_train:
                logger.info(
                    f"Training Results - Epoch: {epoch} Avg-loss: {epoch_loss:.3f}"
                )
                if show:
                    vis.images(
                        [
                            np.clip(pred_mask.detach().cpu().numpy()[0], 0, 1),
                            mask.detach().cpu().numpy()[0]
                        ],
                        opts=dict(title=f'pred img for {epoch}-th iter'))
            else:
                if show:
                    vis.images(
                        [
                            np.clip(pred_mask.detach().cpu().numpy()[0], 0, 1),
                            mask.detach().cpu().numpy()[0]
                        ],
                        opts=dict(title=f'pred img for {epoch}-th iter'))
                logger.info(
                    f"Test Results - Epoch: {epoch} Avg-loss: {epoch_loss:.3f}"
                )
                if scheduler:
                    scheduler.step(epoch_loss)
def main(train_args):
    """Train PSPNet on Pascal VOC, optionally resuming from a snapshot.

    ``train_args`` must provide ``snapshot``, ``input_size``,
    ``train_batch_size``, ``lr``, ``weight_decay`` and ``momentum``; this
    function also (re)sets ``train_args['best_record']``.

    A non-empty ``snapshot`` names a checkpoint in ``ckpt_path/exp_name``
    whose ``_``-separated fields hold epoch, val_loss, acc, acc_cls, mean_iu
    and fwavacc at indices 1, 3, 5, 7, 9 and 11.  The matching
    ``opt_<snapshot>`` optimizer state is restored as well.
    """
    net = PSPNet(num_classes=voc.num_classes).cuda()

    snapshot = train_args['snapshot']
    if len(snapshot) == 0:
        curr_epoch = 1
        train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0,
                                     'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + snapshot)
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        split_snapshot = snapshot.split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11]),
        }

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_simul_transform = simul_transforms.Compose([
        simul_transforms.RandomSized(train_args['input_size']),
        simul_transforms.RandomRotate(10),
        simul_transforms.RandomHorizontallyFlip(),
    ])
    val_simul_transform = simul_transforms.Scale(train_args['input_size'])
    train_input_transform = standard_transforms.Compose([
        extended_transforms.RandomGaussianBlur(),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std),
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    # ``Resize`` is the drop-in replacement for torchvision's deprecated
    # ``Scale``, which recent torchvision releases no longer provide.
    visualize = standard_transforms.Compose([
        standard_transforms.Resize(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor(),
    ])

    train_set = voc.VOC('train', joint_transform=train_simul_transform,
                        transform=train_input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=train_args['train_batch_size'],
                              num_workers=8, shuffle=True)
    val_set = voc.VOC('val', joint_transform=val_simul_transform,
                      transform=val_input_transform,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=True,
                                   ignore_index=voc.ignore_label).cuda()

    # Biases get twice the base learning rate and no weight decay (SGD's
    # default weight_decay is 0); all other parameters use the configured decay.
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters()
                    if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters()
                    if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']},
    ], momentum=train_args['momentum'])

    if len(snapshot) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + snapshot)))
        # The restored state carries the old learning rates; reset them.
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # ':' is not a valid file-name character on Windows.
    log_name = str(datetime.datetime.now()).replace(':', '-') + '.txt'
    with open(os.path.join(ckpt_path, exp_name, log_name), 'w') as log_file:
        log_file.write(str(train_args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, train_args,
          val_loader, restore_transform, visualize)
def main(args):
    """Self-train a segmentation net on target images with generator pseudo-labels.

    Steps, as implemented below:

    1. Build a target-domain training loader and a validation loader whose
       input transform matches ``args.seg_net``.
    2. Load the same weights into a trainable ``net`` and a frozen
       ``net_static``; load the frozen generator ``gen_model`` (``define_G``).
    3. Report mIoU of ``net`` directly and of ``net`` applied to the
       generator's output.
    4. For each target batch, use the argmax of ``net_static`` on the
       generator output as pseudo-labels and train ``net`` on the original
       image against them with cross-entropy (ignore index 255).
    5. Every ``args.checkpoint_freq`` iterations evaluate ``net`` and keep the
       best weights in ``<save_path_prefix>/cityscapes_best_fcn.pth``; at the
       end reload those weights (if any were saved in this run) and evaluate.

    Raises:
        ValueError: if ``args.seg_net`` is neither ``'fcn'`` nor ``'deeplab_ibn'``.
    """
    if args.seg_net not in ('fcn', 'deeplab_ibn'):
        raise ValueError(f"unsupported seg_net: {args.seg_net!r}")

    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    # Generator-side images: ToTensor() followed by Normalize(0.5, 0.5),
    # i.e. [0, 1] mapped to [-1, 1].
    gen_mean_std = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # Seg-net input statistics (used by the validation transforms and by
    # ``seg_normalize`` below).
    imagenet_mean_std = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    fcn_mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    tgt_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*gen_mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    if args.seg_net == 'fcn':
        # Channel flip, [0, 255] scaling and per-channel mean subtraction.
        val_input_transform = standard_transforms.Compose([
            extended_transforms.FreeScale((h, w)),
            extended_transforms.FlipChannels(),
            standard_transforms.ToTensor(),
            standard_transforms.Lambda(lambda x: x.mul_(255)),
            standard_transforms.Normalize(*fcn_mean_std),
        ])
    else:
        val_input_transform = standard_transforms.Compose([
            extended_transforms.FreeScale((h, w)),
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*imagenet_mean_std),
        ])

    tgt_dataset = Cityscapes16DataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        joint_transform=joint_transform,
        transform=tgt_input_transform,
        target_transform=target_transform,
    )
    tgt_loader = data.DataLoader(tgt_dataset, batch_size=args.batch_size,
                                 shuffle=True, num_workers=args.num_workers,
                                 pin_memory=True, drop_last=True)
    val_dataset = Cityscapes16DataSetLMDB(
        args.data_dir_val,
        args.data_list_val,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    val_loader = data.DataLoader(val_dataset, batch_size=args.batch_size,
                                 shuffle=False, num_workers=args.num_workers,
                                 pin_memory=True)

    upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear',
                           align_corners=True)

    if args.seg_net == 'fcn':
        net = FCN8s(args.n_classes, pretrained=False)
        net_static = FCN8s(args.n_classes, pretrained=False)
        file_name = os.path.join(args.resume, args.fcn_name)
    else:  # 'deeplab_ibn'
        # NOTE(review): built with the factory's default arguments - confirm
        # its class count matches the checkpoint / args.n_classes.
        net = resnet101_ibn_a_deeplab()
        net_static = resnet101_ibn_a_deeplab()
        file_name = os.path.join(args.resume, 'deeplab_ibn.pth')

    # load_state_dict copies the tensors, so one loaded dict can seed both nets.
    state_dict = torch.load(file_name)
    net.load_state_dict(state_dict)
    net_static.load_state_dict(state_dict)
    for param in net_static.parameters():
        param.requires_grad = False

    optimizer = torch.optim.SGD(net.parameters(), args.learning_rate, args.momentum)
    net = torch.nn.DataParallel(net.cuda())
    net_static = torch.nn.DataParallel(net_static.cuda())
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    gen_model = define_G()
    gen_model.load_state_dict(torch.load(os.path.join(args.resume, args.gen_name)))
    gen_model.eval()
    for param in gen_model.parameters():
        param.requires_grad = False
    gen_model = torch.nn.DataParallel(gen_model.cuda())

    def de_normalize(x, mean=gen_mean_std[0], std=gen_mean_std[1]):
        """Undo ``Normalize(mean, std)`` on an (N, C, H, W) batch: ``x * std + mean``."""
        mean_t = x.new_tensor(mean).view(1, -1, 1, 1)
        std_t = x.new_tensor(std).view(1, -1, 1, 1)
        return x * std_t + mean_t

    def seg_normalize(x):
        """Convert a de-normalised (N, 3, H, W) batch to the seg net's input space."""
        if args.seg_net == 'fcn':
            # Reverse the channel order, rescale to [0, 255] and subtract the
            # FCN means, mirroring the FCN validation transform.
            mean = x.new_tensor(fcn_mean_std[0]).view(1, 3, 1, 1)
            return x[:, [2, 1, 0], :, :] * 255.0 - mean
        # Standardise with ImageNet statistics, as the non-FCN validation
        # transform does.
        mean = x.new_tensor(imagenet_mean_std[0]).view(1, 3, 1, 1)
        std = x.new_tensor(imagenet_mean_std[1]).view(1, 3, 1, 1)
        return (x - mean) / std

    def make_pseudo_label(logits):
        """Return per-pixel argmax labels of ``logits`` (N, C, H, W), some set to 255."""
        max_value, label = torch.max(torch.softmax(logits, dim=1), dim=1)
        label_mask = torch.zeros_like(label, dtype=torch.bool)
        for cls in range(logits.shape[1]):
            cls_mask = label == cls
            if torch.sum(cls_mask) < 5:
                continue
            value_vec = max_value[cls_mask]
            k = int(args.percent * value_vec.shape[0])
            if k == 0:
                # topk(..., 0) is empty, so indexing its first value would raise.
                continue
            # NOTE(review): ``topk(...)[0][0]`` is this class's *maximum*
            # confidence, so the strict ``>`` below selects no pixel and no
            # label is ever set to 255. The k-th largest value (``[0][-1]``)
            # may have been intended - confirm before changing.
            large_value = torch.topk(value_vec, k)[0][0]
            label_mask |= cls_mask & (max_value > large_value)
        label[label_mask] = 255
        return label

    class NewModel(object):
        """Evaluate ``val_net`` on the output of ``gen_net`` (for ``test_miou``)."""

        def __init__(self, gen_net, val_net):
            self.gen_net = gen_net
            self.val_net = val_net

        def __call__(self, x):
            return self.val_net(seg_normalize(de_normalize(self.gen_net(x))))

        def eval(self):
            self.gen_net.eval()
            self.val_net.eval()

    # ###################################################
    # direct test with gen
    # ###################################################
    print('Direct Test')
    test_miou(net, val_loader, upsample, './dataset/info.json')

    # Generator inputs use the same [-1, 1] normalisation as training images.
    direct_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*gen_mean_std),
    ])
    val_dataset_direct = Cityscapes16DataSetLMDB(
        args.data_dir_val,
        args.data_list_val,
        transform=direct_input_transform,
        target_transform=target_transform,
    )
    val_loader_direct = data.DataLoader(val_dataset_direct,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers,
                                        pin_memory=True, drop_last=False)

    print('Test with Gen')
    test_miou(NewModel(gen_model, net), val_loader_direct, upsample,
              './dataset/info.json')

    num_batches = len(tgt_loader)
    best_path = os.path.join(args.save_path_prefix, 'cityscapes_best_fcn.pth')
    highest = 0
    saved_best = False
    for epoch in range(args.num_epoch):
        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()
        tem_time = time.time()
        for batch_index, batch_data in enumerate(tgt_loader):
            iteration = batch_index + 1 + epoch * num_batches

            # test_miou may switch modes, so reset them every iteration.
            net.train()
            net_static.eval()  # the frozen labeller always runs in eval mode

            img, _, _ = batch_data
            img = img.cuda()
            data_time_rec.update(time.time() - tem_time)

            with torch.no_grad():
                gen_output = gen_model(img)
                gen_seg_output_logits = net_static(
                    seg_normalize(de_normalize(gen_output)))
                label = make_pseudo_label(gen_seg_output_logits)
            ori_seg_output_logits = net(seg_normalize(de_normalize(img)))

            loss = criterion(ori_seg_output_logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_rec.update(loss.item())
            writer.add_scalar('A_seg_loss', loss.item(), iteration)
            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f} '
                    f'Data: {data_time_rec.avg:.2f} '
                    f'Loss: {loss_rec.avg:.2f}')

            if iteration % args.checkpoint_freq == 0:
                mean_iu = test_miou(net, val_loader, upsample,
                                    './dataset/info.json', print_results=False)
                if mean_iu > highest:
                    torch.save(net.module.state_dict(), best_path)
                    highest = mean_iu
                    saved_best = True
                    print(f'save fcn model with {mean_iu:.2%}')

    print(('-' * 100 + '\n') * 3)
    print('>' * 50 + 'Final Model')
    # Only reload when this run actually wrote a best checkpoint; otherwise
    # evaluate the current weights instead of failing (or loading a stale file).
    if saved_best:
        net.module.load_state_dict(torch.load(best_path))
    test_miou(net, val_loader, upsample, './dataset/info.json')
    writer.close()
def load_dataset(self):
    """Create the saliency train/valid datasets and their DataLoaders.

    Reads its configuration from ``self.args`` (``input_normalize``,
    ``max_size``, ``crop_size``, ``dataset_dir``, ``datasets``,
    ``weighted_sampler``, ``train_num_sampler``, ``valid_num_sampler``,
    ``batch_size``, ``test_batch_size``, ``workers``).

    Side effects: sets ``self.train_sampler``, ``self.train_loader``,
    ``self.valid_sampler`` and ``self.valid_loader``.  When
    ``args.weighted_sampler`` is true, a ``sampler.DatasetWeightedSampler``
    built from the dataset's ``weight_list`` is used instead of shuffling.
    """
    args = self.args

    # Mean/std per channel: ImageNet statistics on a [0, 1] scale when
    # ``input_normalize`` is set, otherwise values on a 0-255 scale.
    if args.input_normalize:
        rgb_mean, rgb_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    else:
        rgb_mean, rgb_std = [122.675, 116.669, 104.008], [58.395, 57.12, 57.375]

    # Train joint transforms (applied to image and target together).
    train_jts = joint_transforms.Compose([
        joint_transforms.Resize(args.max_size),
        # Cast to float64 and divide by 255. ``th=[False, True]`` presumably
        # restricts this to the target only - confirm in joint_transforms.Apply.
        # ``np.float64`` replaces ``np.float`` (an alias of the builtin
        # ``float``, i.e. float64), which NumPy 1.24 removed.
        joint_transforms.Apply(
            tv_transforms.Lambda(lambda img: img.astype(np.float64) / 255),
            th=[False, True]),
        joint_transforms.RandomCrop(args.crop_size, ignore_label=255),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    # Train source (image) transforms.
    train_sts = tv_transforms.Compose([
        extend_transforms.ImageToTensor(args.input_normalize),
        tv_transforms.Normalize(mean=rgb_mean, std=rgb_std),
    ])
    # Train target transforms.
    train_tts = extend_transforms.GrayImageToTensor(False)

    train_dataset = saliency.SaliencyMergedData(
        root=args.dataset_dir,
        phase='train',
        dataset_list=args.datasets,
        joint_transform=train_jts,
        source_transform=train_sts,
        target_transform=train_tts,
    )
    self.train_sampler = None
    if args.weighted_sampler:
        self.train_sampler = sampler.DatasetWeightedSampler(
            train_dataset.weight_list, args.train_num_sampler)
    self.train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        drop_last=True,
        num_workers=args.workers,
        # A custom sampler and shuffle=True are mutually exclusive.
        shuffle=(self.train_sampler is None),
        pin_memory=True,
        sampler=self.train_sampler,
    )

    # Valid joint / source / target transforms.
    valid_jts = joint_transforms.Compose([
        joint_transforms.Resize(args.crop_size),
    ])
    valid_sts = tv_transforms.Compose([
        extend_transforms.ImageToTensor(args.input_normalize),
        tv_transforms.Normalize(mean=rgb_mean, std=rgb_std),
    ])
    valid_tts = extend_transforms.GrayImageToTensor()

    valid_dataset = saliency.SaliencyMergedData(
        root=args.dataset_dir,
        phase='valid',
        dataset_list=args.datasets,
        joint_transform=valid_jts,
        source_transform=valid_sts,
        target_transform=valid_tts,
    )
    self.valid_sampler = None
    if args.weighted_sampler:
        self.valid_sampler = sampler.DatasetWeightedSampler(
            valid_dataset.weight_list, args.valid_num_sampler)
    # NOTE(review): drop_last=True discards the final partial validation
    # batch, so those samples are never evaluated - confirm this is intended.
    self.valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        drop_last=True,
        batch_size=args.test_batch_size,
        num_workers=args.workers,
        shuffle=False,
        pin_memory=True,
        sampler=self.valid_sampler,
    )
def main():
    """Train FCN8s with a VGG backbone on Cityscapes using a fixed seed.

    Configuration comes from the module-level ``args`` dict; a timestamped
    copy of ``args`` is written under ``ckpt_path/exp_name``.  A non-empty
    ``args['snapshot']`` names a checkpoint there whose ``_``-separated
    fields hold epoch (1), val_loss (3), acc (5), acc_cls (7) and mean_iu (9,
    with its last 4 characters - presumably the file extension - stripped).
    Only the network weights are restored; the optimizer starts fresh.
    """
    torch.backends.cudnn.benchmark = True
    # Only effective if CUDA has not been initialised yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

    # Fixed seed so that runs are repeatable.
    seed = 63
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Validation images are scaled so that the ``input_size`` center crop
    # covers 87.5% of the scaled size.
    short_size = int(min(args['input_size']) / 0.875)

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.RandomCrop(args['crop_size']),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomRotate(90),
    ])
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size']),
    ])
    input_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(*mean_std)])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = transforms.Compose(
        [extended_transforms.DeNormalize(*mean_std), transforms.ToPILImage()])
    visualize = transforms.ToTensor()

    train_set = cityscapes.CityScapes('train',
                                      joint_transform=train_joint_transform,
                                      transform=input_transform,
                                      target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=args['train_batch_size'],
                              num_workers=8, shuffle=True)
    val_set = cityscapes.CityScapes('val',
                                    joint_transform=val_joint_transform,
                                    transform=input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=args['val_batch_size'],
                            num_workers=8, shuffle=True)

    print(len(train_loader), len(val_loader))

    vgg_model = VGGNet(requires_grad=True, remove_fc=True)
    net = FCN8s(pretrained_net=vgg_model, n_class=cityscapes.num_classes,
                dropout_rate=0.4)
    criterion = nn.CrossEntropyLoss(ignore_index=cityscapes.ignore_label)
    optimizer = optim.Adam(net.parameters(), lr=1e-4)

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # ':' is not a valid file-name character on Windows.
    log_name = str(datetime.datetime.now()).replace(':', '-') + '.txt'
    with open(os.path.join(ckpt_path, exp_name, log_name), 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=args['lr_patience'], min_lr=1e-10)

    vgg_model = vgg_model.to(device)
    net = net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    snapshot = args['snapshot']
    if len(snapshot) == 0:
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
        }
    else:
        print('training resumes from ' + snapshot)
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        split_snapshot = snapshot.split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9][:-4]),
        }

    criterion.to(device)

    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, device, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, device, criterion, optimizer,
                            epoch, args, restore_transform, visualize)
        scheduler.step(val_loss)
def main(train_args):
    """Train PSPNet on Pascal VOC with 224x224 center crops and plain SGD.

    ``train_args`` must provide ``snapshot``, ``lr``, ``momentum``,
    ``weight_decay`` and ``epoch_num``; this function also sets
    ``train_args['best_record']``.  A non-empty ``snapshot`` names a
    checkpoint in ``ckpt_path/exp_name`` whose ``_``-separated fields hold
    epoch, val_loss, acc, acc_cls, mean_iu and fwavacc at indices 1, 3, 5,
    7, 9 and 11; only the network weights are restored.
    ``adjust_learning_rate`` is called after every epoch.
    """
    net = PSPNet(num_classes=voc.num_classes).cuda()

    snapshot = train_args['snapshot']
    if len(snapshot) == 0:
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0,
        }
    else:
        print('training resumes from ' + snapshot)
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        split_snapshot = snapshot.split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11]),
        }

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    input_transform = standard_transforms.Compose([
        ToTensor(),
        Normalize(*mean_std),
    ])
    joint_transform = joint_transforms.Compose([
        joint_transforms.CenterCrop(224),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    target_transform = standard_transforms.Compose([
        extended_transforms.MaskToTensor(),
    ])
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    # ``Resize`` is the drop-in replacement for torchvision's deprecated
    # ``Scale``, which recent torchvision releases no longer provide.
    visualize = standard_transforms.Compose([
        standard_transforms.Resize(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor(),
    ])
    # Validation crops are deterministic, so image and mask are cropped
    # independently without a joint transform.
    val_input_transform = standard_transforms.Compose([
        CenterCrop(224),
        ToTensor(),
        Normalize(*mean_std),
    ])
    val_target_transform = standard_transforms.Compose([
        CenterCrop(224),
        extended_transforms.MaskToTensor(),
    ])

    train_set = voc.VOC('train', transform=input_transform,
                        target_transform=target_transform,
                        joint_transform=joint_transform)
    train_loader = DataLoader(train_set, batch_size=4, num_workers=4, shuffle=True)
    val_set = voc.VOC('val', transform=val_input_transform,
                      target_transform=val_target_transform)
    val_loader = DataLoader(val_set, batch_size=4, num_workers=4, shuffle=False)

    criterion = torch.nn.CrossEntropyLoss(ignore_index=voc.ignore_label).cuda()
    optimizer = optim.SGD(net.parameters(), lr=train_args['lr'],
                          momentum=train_args['momentum'],
                          weight_decay=train_args['weight_decay'])

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))

    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args)
        validate(val_loader, net, criterion, optimizer, epoch, train_args,
                 restore_transform, visualize)
        adjust_learning_rate(optimizer, epoch, net, train_args)
# ---------------------------------------------------------------------------
# Experiment configuration (module-level settings read by the training code)
# ---------------------------------------------------------------------------
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Optimisation
batch_size = 32
lr = 1e-4
weight_decay = 1e-5
epochs = 600

# Network architecture
num_filters = [16, 32]
num_conv_blocks = 1                       # U-Net setting
num_convs_per_block = num_conv_blocks     # probabilistic U-Net only
num_convs_fcomb = 1
partial_data = False
latent_dim = 6
beta = 10
# Initialiser names for weights ('w') and biases ('b'); the original notes
# mention 'kaiming_normal' and 'orthogonal' as weight options.
initializers = dict(w='kaiming_normal', b='normal')

# Data transforms
joint_transfm = joint_transforms.Compose([
    joint_transforms.RandomHorizontallyFlip(),
    joint_transforms.RandomSizedCrop(128),
    joint_transforms.RandomRotate(60),
])
input_transfm = transforms.Compose([transforms.ToPILImage()])
target_transfm = transforms.Compose([transforms.ToTensor()])
def get_transforms(scale_size, input_size, region_size, supervised, test,
                   al_algorithm, full_res, dataset):
    """Build the input, target and joint transforms for training/validation.

    Args:
        scale_size: size passed to ``joint_transforms.Scale`` (reported as
            the width dimension); ``0`` disables rescaling.
        input_size: random-crop size for training.
        region_size: region size for the region-aware crops.
        supervised: use plain ``Compose``/``RandomCrop`` for training instead
            of the region-aware ``ComposeRegion``/``RandomCropRegion``.
        test: with ``scale_size == 0``, validation is rescaled to 1024 only
            when this is false, ``al_algorithm == 'ralis'`` and ``full_res``
            is false.
        al_algorithm: see ``test``.
        full_res: see ``test``.
        dataset: with a non-zero ``scale_size``, ``'gta_for_camvid'`` also
            rescales validation images.

    Returns:
        ``(input_transform, target_transform, train_joint_transform,
        val_joint_transform, al_train_joint_transform)``; the validation
        joint transform may be ``None``.
    """
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    no_scaling = scale_size == 0

    if no_scaling:
        print('(Data loading) Not scaling the data')
    else:
        print(f'(Data loading) Scaling training data: {scale_size!s} width dimension')
    print(f'(Data loading) Random crops of {input_size!s} in training')
    print('(Data loading) No crops in validation' if no_scaling
          else '(Data loading) No crops nor scale_size in validation')

    def with_scale(steps):
        # Prepend a fresh Scale step whenever the data is rescaled.
        prefix = [] if no_scaling else [joint_transforms.Scale(scale_size)]
        return prefix + steps

    if supervised:
        train_joint_transform = joint_transforms.Compose(with_scale([
            joint_transforms.RandomCrop(input_size),
            joint_transforms.RandomHorizontallyFlip(),
        ]))
    else:
        train_joint_transform = joint_transforms.ComposeRegion(with_scale([
            joint_transforms.RandomCropRegion(input_size, region_size=region_size),
            joint_transforms.RandomHorizontallyFlip(),
        ]))

    al_train_joint_transform = joint_transforms.ComposeRegion(with_scale([
        joint_transforms.CropRegion(region_size, region_size=region_size),
        joint_transforms.RandomHorizontallyFlip(),
    ]))

    if no_scaling:
        downscale_val = not test and al_algorithm == 'ralis' and not full_res
        val_joint_transform = joint_transforms.Scale(1024) if downscale_val else None
    elif dataset == 'gta_for_camvid':
        val_joint_transform = joint_transforms.ComposeRegion(
            [joint_transforms.Scale(scale_size)])
    else:
        val_joint_transform = None

    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()
    return (input_transform, target_transform, train_joint_transform,
            val_joint_transform, al_train_joint_transform)
def main(args):
    """Train ``resnet101_ibn_a_deeplab`` on GTA5 or Cityscapes, saving every epoch.

    The source dataset class is chosen by whether ``'5'`` occurs in
    ``args.data_dir`` (GTA5) or not (Cityscapes).  Loss values and, every
    ``args.show_img_freq`` batches, a prediction/label figure are logged to
    TensorBoard.  After each epoch the model is evaluated on
    ``args.data_dir_target`` with ``test_miou`` and saved as
    ``<save_path_prefix>/<epoch>_<round(mIoU*100)>.pth``.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    target_transform = extended_transforms.MaskToTensor()

    # The source dataset class is picked from the data directory name.
    dataset_cls = GTA5DataSetLMDB if '5' in args.data_dir else CityscapesDataSetLMDB
    dataset = dataset_cls(
        args.data_dir,
        args.data_list,
        joint_transform=joint_transform,
        transform=input_transform,
        target_transform=target_transform,
    )
    loader = data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                             num_workers=args.num_workers, pin_memory=True)

    val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        transform=input_transform,
        target_transform=target_transform,
    )
    val_loader = data.DataLoader(val_dataset, batch_size=args.batch_size,
                                 shuffle=True, num_workers=args.num_workers,
                                 pin_memory=True)

    upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear',
                           align_corners=True)

    net = resnet101_ibn_a_deeplab(args.model_path_prefix, n_classes=args.n_classes)
    optimizer = torch.optim.SGD(net.parameters(), args.learning_rate, args.momentum)
    # DataParallel requires the wrapped module's parameters on device_ids[0];
    # .cuda() moves them in place, so the optimizer's references stay valid.
    net = torch.nn.DataParallel(net.cuda())
    # reduction='sum' is the non-deprecated spelling of size_average=False.
    criterion = torch.nn.CrossEntropyLoss(reduction='sum',
                                          ignore_index=args.ignore_index)

    num_batches = len(loader)
    for epoch in range(args.num_epoch):
        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, batch_data in enumerate(loader):
            show_fig = (batch_index + 1) % args.show_img_freq == 0
            iteration = batch_index + 1 + epoch * num_batches

            # test_miou may switch to eval mode, so re-enable training mode.
            net.train()
            img, label, name = batch_data
            img = img.cuda()
            label_cuda = label.cuda()
            data_time_rec.update(time.time() - tem_time)

            output = net(img)
            loss = criterion(output, label_cuda)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_rec.update(loss.item())
            writer.add_scalar('A_seg_loss', loss.item(), iteration)
            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f} '
                    f'Data: {data_time_rec.avg:.2f} '
                    f'Loss: {loss_rec.avg:.2f}')
            if show_fig:
                base_lr = optimizer.param_groups[0]["lr"]
                pred = torch.argmax(output, dim=1).detach()[0, ...].cpu()
                fig, axes = plt.subplots(2, 1, figsize=(12, 14))
                axes = axes.flat
                axes[0].imshow(colorize_mask(pred.numpy()))
                axes[0].set_title(name[0])
                axes[1].imshow(colorize_mask(label[0, ...].numpy()))
                axes[1].set_title(f'seg_true_{base_lr:.6f}')
                writer.add_figure('A_seg', fig, iteration)

        mean_iu = test_miou(net, val_loader, upsample, './ae_seg/dataset/info.json')
        torch.save(
            net.module.state_dict(),
            os.path.join(args.save_path_prefix, f'{epoch:d}_{mean_iu*100:.0f}.pth'))

    writer.close()
def main():
    """Train FCN32VGG on Mapillary semantic segmentation.

    Configuration comes from the module-level ``args`` dict; a timestamped
    copy of ``args`` is written under ``ckpt_path/exp_name``.  A non-empty
    ``args['snapshot']`` resumes both the network and the matching
    ``opt_<snapshot>`` optimizer state; its ``_``-separated fields hold
    epoch, val_loss, acc, acc_cls, mean_iu and fwavacc at indices 1, 3, 5,
    7, 9 and 11.  The final weights are saved to ``PATH``.
    """
    net = FCN32VGG(num_classes=mapillary.num_classes).cuda()

    snapshot = args['snapshot']
    if len(snapshot) == 0:
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0,
        }
    else:
        print('training resumes from ' + snapshot)
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        split_snapshot = snapshot.split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11]),
        }

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Images are scaled so that the ``input_size`` crop covers 87.5% of the
    # scaled size.
    short_size = int(min(args['input_size']) / 0.875)
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.RandomCrop(args['input_size']),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size']),
    ])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.ToTensor()

    train_set = mapillary.Mapillary('semantic', 'training',
                                    joint_transform=train_joint_transform,
                                    transform=input_transform,
                                    target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=args['train_batch_size'],
                              num_workers=8, shuffle=True, pin_memory=True)
    val_set = mapillary.Mapillary('semantic', 'validation',
                                  joint_transform=val_joint_transform,
                                  transform=input_transform,
                                  target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=args['val_batch_size'],
                            num_workers=8, shuffle=False, pin_memory=True)

    criterion = CrossEntropyLoss2d(size_average=False).cuda()

    # Biases get twice the base learning rate and no weight decay (SGD's
    # default weight_decay is 0); all other parameters use the configured decay.
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters()
                    if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters()
                    if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']},
    ], momentum=args['momentum'])

    if len(snapshot) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + snapshot)))
        # The restored state carries the old learning rates; reset them.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # ':' is not a valid file-name character on Windows.
    log_name = str(datetime.datetime.now()).replace(':', '-') + '.txt'
    with open(os.path.join(ckpt_path, exp_name, log_name), 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=args['lr_patience'],
                                  min_lr=1e-10)
    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch, args,
                            restore_transform, visualize)
        scheduler.step(val_loss)

    torch.save(net.state_dict(), PATH)
def train_with_ignite(networks, dataset, data_dir, batch_size, img_size,
                      epochs, lr, momentum, num_workers, optimizer, logger):
    """Train a segmentation network with pytorch-ignite, checkpointing each epoch.

    Args:
        networks: network identifier passed to ``get_network``. It is also
            used in the checkpoint file name.
        dataset: dataset identifier passed to ``get_loader``.
        data_dir: dataset root passed to ``get_loader``.
        batch_size: training batch size. Evaluation on the test split uses
            batch size 1.
        img_size: side length of the square random training crop.
        epochs: number of training epochs.
        lr: learning rate passed to ``get_optimizer``.
        momentum: momentum passed to ``get_optimizer``.
        num_workers: DataLoader worker count for both loaders.
        optimizer: optimizer identifier passed to ``get_optimizer``. It is
            also used in the checkpoint file name.
        logger: logger receiving progress and metric messages.

    After every epoch the model is evaluated on both the training and test
    loaders. The LR scheduler steps on the test loss, and a checkpoint built
    by ``update_state`` is written under ``./ckpt/``.
    """
    from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
    from ignite.metrics import Loss
    from utils.metrics import MultiThresholdMeasures, Accuracy, IoU, F1score

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Build the model and log its summary (this also moves it to `device`).
    model = get_network(networks)
    input_size = (3, img_size, img_size)
    summarize_model(model.to(device), input_size, logger, batch_size, device)

    loss = torch.nn.BCEWithLogitsLoss()

    model_optimizer = get_optimizer(optimizer, model, lr, momentum)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(model_optimizer)

    # Transforms applied jointly to image and mask.
    train_joint_transforms = jnt_trnsf.Compose([
        jnt_trnsf.RandomCrop(img_size),
        jnt_trnsf.RandomRotate(5),
        jnt_trnsf.RandomHorizontallyFlip(),
    ])
    # Transforms applied to images only.
    train_image_transforms = std_trnsf.Compose([
        std_trnsf.ColorJitter(0.05, 0.05, 0.05, 0.05),
        std_trnsf.ToTensor(),
        std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    test_joint_transforms = jnt_trnsf.Compose([jnt_trnsf.Safe32Padding()])
    test_image_transforms = std_trnsf.Compose([
        std_trnsf.ToTensor(),
        std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Transforms applied to masks only.
    mask_transforms = std_trnsf.Compose([std_trnsf.ToTensor()])

    # NOTE(review): the training loader is built with shuffle=False; confirm
    # this is intentional (training data is usually shuffled).
    train_loader = get_loader(dataset=dataset, data_dir=data_dir, train=True,
                              joint_transforms=train_joint_transforms,
                              image_transforms=train_image_transforms,
                              mask_transforms=mask_transforms,
                              batch_size=batch_size, shuffle=False,
                              num_workers=num_workers)
    test_loader = get_loader(dataset=dataset, data_dir=data_dir, train=False,
                             joint_transforms=test_joint_transforms,
                             image_transforms=test_image_transforms,
                             mask_transforms=mask_transforms,
                             batch_size=1, shuffle=False,
                             num_workers=num_workers)

    trainer = create_supervised_trainer(model, model_optimizer, loss,
                                        device=device)
    measure = MultiThresholdMeasures()
    evaluator = create_supervised_evaluator(model, metrics={
        '': measure,
        'pix-acc': Accuracy(measure),
        'iou': IoU(measure),
        'loss': Loss(loss),
        'f1': F1score(measure),
    }, device=device)

    # Latest checkpoint record. The epoch handlers below rebind it; it is
    # read back when saving.
    state = update_state(model.state_dict(), 0, 0, 0, 0, 0)

    ckpt_root = './ckpt/'
    os.makedirs(ckpt_root, exist_ok=True)
    filename = '{network}_{optimizer}_lr_{lr}_epoch_{epoch}.pth'
    ckpt_path = os.path.join(ckpt_root, filename)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(trainer):
        """Log the batch loss every 20 iterations within an epoch."""
        # 1-based iteration index within the current epoch.
        num_iter = (trainer.state.iteration - 1) % len(train_loader) + 1
        if num_iter % 20 == 0:
            logger.info("Epoch[{}] Iter[{:03d}] Loss: {:.2f}".format(
                trainer.state.epoch, num_iter, trainer.state.output))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(trainer):
        """Evaluate on the training set and record the training loss."""
        nonlocal state
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        logger.info(
            "Training Results - Epoch: {} Avg-loss: {:.3f}\n Pix-acc: {}\n IoU: {}\n F1: {}\n"
            .format(trainer.state.epoch, metrics['loss'],
                    str(metrics['pix-acc']), str(metrics['iou']),
                    str(metrics['f1'])))
        # update_state returns the new record; it must be reassigned, or the
        # checkpoint keeps the initial zeros.
        state = update_state(weight=model.state_dict(),
                             train_loss=metrics['loss'],
                             val_loss=state['val_loss'],
                             val_pix_acc=state['val_pix_acc'],
                             val_iou=state['val_iou'],
                             val_f1=state['val_f1'])

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        """Evaluate on the test split, step the LR scheduler and save a checkpoint."""
        nonlocal state
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        logger.info(
            "Validation Results - Epoch: {} Avg-loss: {:.3f}\n Pix-acc: {}\n IoU: {}\n F1: {}\n"
            .format(trainer.state.epoch, metrics['loss'],
                    str(metrics['pix-acc']), str(metrics['iou']),
                    str(metrics['f1'])))

        lr_scheduler.step(metrics['loss'])

        state = update_state(weight=model.state_dict(),
                             train_loss=state['train_loss'],
                             val_loss=metrics['loss'],
                             val_pix_acc=metrics['pix-acc'],
                             val_iou=metrics['iou'],
                             val_f1=metrics['f1'])
        path = ckpt_path.format(network=networks, optimizer=optimizer,
                                lr=lr, epoch=trainer.state.epoch)
        save_ckpt_file(path, state)

    trainer.run(train_loader, max_epochs=epochs)
} # Paths to trained models & epoch counts DUCHDC_trainedModelPath = './ducModelFinal.pth' FCN8_trainedModelPath = './fcnModelFinal.pth' Unet_trainedModelPath = './unetModelFinal.pth' # Transforms mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) short_size = int(min(args['input_size']) / 0.875) joint_transform = joint_transforms.Compose([ joint_transforms.Scale(short_size), joint_transforms.RandomCrop(args['input_size']), joint_transforms.RandomHorizontallyFlip() ]) input_transform = standard_transforms.Compose( [standard_transforms.ToTensor(), standard_transforms.Normalize(*mean_std)]) target_transform = extended_transforms.MaskToTensor() restore_transform = standard_transforms.Compose([ extended_transforms.DeNormalize(*mean_std), standard_transforms.ToPILImage() ]) visualize = standard_transforms.ToTensor()
def main():
    """Train PSPNet on PASCAL VOC with sliding-window crops, optionally resuming.

    All settings are read from the module-level ``args`` dict. When
    ``args['snapshot']`` is non-empty, model and optimizer weights are reloaded
    from ``<ckpt_path>/<exp_name>/``. The starting epoch and the best-record
    metrics are parsed from the snapshot file name: it is split on ``_`` and
    the numeric fields are read from indices 1, 3, 5, 7, 9 and 11.

    Side effects: sets ``args['best_record']``, creates the checkpoint
    directories and writes a timestamped ``.txt`` file containing
    ``str(args)``. The epoch loop itself is delegated to ``train``.
    """
    net = PSPNet(num_classes=voc.num_classes).cuda()

    snapshot = args['snapshot']
    if not snapshot:
        # Optional warm start, kept for reference:
        # net.load_state_dict(torch.load(os.path.join(ckpt_path, 'cityscapes (coarse)-psp_net', 'xx.pth')))
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0,
        }
    else:
        print('training resumes from ' + snapshot)
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        # The snapshot name alternates label/value; values sit at odd indices.
        split_snapshot = snapshot.split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11]),
        }

    net.train()

    # ImageNet channel statistics.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    # Shared by train and val; crop padding uses the VOC ignore label.
    sliding_crop = joint_transforms.SlidingCrop(args['crop_size'],
                                                args['stride_rate'],
                                                voc.ignore_label)
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std),
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()
    visualize = standard_transforms.Compose([
        standard_transforms.Resize(args['val_img_display_size']),
        standard_transforms.ToTensor(),
    ])

    train_set = voc.VOC('train', joint_transform=train_joint_transform,
                        sliding_crop=sliding_crop,
                        transform=train_input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=args['train_batch_size'],
                              num_workers=1, shuffle=True, drop_last=True)
    val_set = voc.VOC('val', transform=val_input_transform,
                      sliding_crop=sliding_crop,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=1,
                            shuffle=False, drop_last=True)

    criterion = CrossEntropyLoss2d(size_average=True,
                                   ignore_index=voc.ignore_label).cuda()

    # Biases get twice the base learning rate and no weight decay. All other
    # parameters use the base rate with weight decay.
    named_params = list(net.named_parameters())
    optimizer = optim.SGD([
        {'params': [p for name, p in named_params if name.endswith('bias')],
         'lr': 2 * args['lr']},
        {'params': [p for name, p in named_params
                    if not name.endswith('bias')],
         'lr': args['lr'],
         'weight_decay': args['weight_decay']},
    ], momentum=args['momentum'], nesterov=True)

    if snapshot:
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + snapshot)))
        # load_state_dict restores the saved learning rates; override them
        # with the currently configured ones.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # ':' is replaced because it is not allowed in Windows file names.
    log_name = str(datetime.datetime.now()).replace(':', '-') + '.txt'
    with open(os.path.join(ckpt_path, exp_name, log_name), 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, args,
          val_loader, visualize)
def main():
    """Fine-tune a Cityscapes-pretrained PSPNet head for binary KITTI segmentation.

    The pretrained backbone is frozen. The final conv-BN-ReLU, dropout and
    classification layers are replaced, and only these new layers are
    trainable. All settings are read from the module-level ``args`` dict.
    When ``args['snapshot']`` is non-empty, model and optimizer weights are
    reloaded from ``<ckpt_path>/<exp_name>/``. The best-record metrics are
    parsed from the snapshot file name: it is split on ``_`` and the values
    are read from indices 1, 3, 5 and 7.

    Side effects: sets ``args['best_record']``, creates the checkpoint
    directories and writes a timestamped ``.txt`` file containing
    ``str(args)``. The training loop itself is delegated to ``train``.
    """
    net = PSPNet(19)
    net.load_pretrained_model(
        model_path='./Caffe-PSPNet/pspnet101_cityscapes.caffemodel')
    # Freeze every pretrained weight. The replacement layers created below
    # are new modules, so they remain trainable.
    for param in net.parameters():
        param.requires_grad = False
    net.cbr_final = conv2DBatchNormRelu(4096, 128, 3, 1, 1, False)
    net.dropout = nn.Dropout2d(p=0.1, inplace=True)
    net.classification = nn.Conv2d(128, kitti_binary.num_classes, 1, 1, 0)

    # Report total vs trainable parameter counts.
    total_params = sum(p.numel() for p in net.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(
        p.numel() for p in net.parameters() if p.requires_grad)
    print(f'{total_trainable_params:,} training parameters.')

    snapshot = args['snapshot']
    if not snapshot:
        # Optional warm start, kept for reference:
        # net.load_state_dict(torch.load(os.path.join(ckpt_path, 'cityscapes (coarse)-psp_net', 'xx.pth')))
        args['best_record'] = {
            'epoch': 0,
            'iter': 0,
            'val_loss': 1e10,
            'accu': 0,
        }
    else:
        print('training resumes from ' + snapshot)
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        # The snapshot name alternates label/value; values sit at odd indices.
        split_snapshot = snapshot.split('_')
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'iter': int(split_snapshot[3]),
            'val_loss': float(split_snapshot[5]),
            'accu': float(split_snapshot[7]),
        }

    net.cuda(args['gpu']).train()

    # Per-channel means subtracted after the tensor is rescaled to 0-255
    # (std of 1.0 means no division). NOTE(review): presumably these are
    # Caffe BGR means matching FlipChannels and the Caffe weights; confirm.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    train_input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    train_set = kitti_binary.KITTI(mode='train',
                                   joint_transform=train_joint_transform,
                                   transform=train_input_transform,
                                   target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=args['train_batch_size'],
                              num_workers=8, shuffle=True)
    val_set = kitti_binary.KITTI(mode='val', transform=val_input_transform,
                                 target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=8,
                            shuffle=False)

    # pos_weight slightly up-weights the positive class.
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.full([1], 1.05)).cuda(
        args['gpu'])

    # Biases get twice the base learning rate and no weight decay. All other
    # parameters use the base rate with weight decay. Frozen parameters are
    # still listed; their grads stay None.
    named_params = list(net.named_parameters())
    optimizer = optim.SGD([
        {'params': [p for name, p in named_params if name.endswith('bias')],
         'lr': 2 * args['lr']},
        {'params': [p for name, p in named_params
                    if not name.endswith('bias')],
         'lr': args['lr'],
         'weight_decay': args['weight_decay']},
    ], momentum=args['momentum'], nesterov=True)

    if snapshot:
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + snapshot)))
        # load_state_dict restores the saved learning rates; override them
        # with the currently configured ones.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # ':' is replaced because it is not allowed in Windows file names.
    log_name = str(datetime.datetime.now()).replace(':', '-') + '.txt'
    with open(os.path.join(ckpt_path, exp_name, log_name), 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    train(train_loader, net, criterion, optimizer, args, val_loader)