# ---------------------------------------------------------------------------
# Smoke test: build lightSeg on CPU, run one dummy forward pass, print the
# network, the output size, a layer summary and the total parameter count.
# ---------------------------------------------------------------------------
# NOTE: torch.randn replaces torch.Tensor(1, 3, 512, 512), which returned an
# *uninitialised* tensor (arbitrary memory, possibly NaN/Inf) — fine for a
# pure shape check but unsafe to feed through a network.  The variable is
# also renamed so it no longer shadows the `input` builtin.
dummy_input = torch.randn(1, 3, 512, 512)
model = lightSeg(backbone='resnet101', n_classes=3, pretrained=False)
model.eval()
print(model)
output = model(dummy_input)
print('lightSeg', output.size())
summary(model, (3, 512, 512), device='cpu')

total_paramters = netParams(model)
print("the number of parameters: %d ==> %.2f M" % (total_paramters, (total_paramters / 1e6)))
def train_model(args):
    """
    Full training driver: builds the model, dataset loaders, loss and
    optimizer from ``args``, then runs the epoch loop with periodic
    validation, checkpointing, logging and plot generation.

    args:
       args: global arguments (parsed command-line namespace); mutated in
             place (``args.gpu_nums``, ``args.savedir``).
    """
    # Parse "H,W" crop size string into an (h, w) tuple.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    print("=====> input size:{}".format(input_size))
    print(args)

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        # Restrict visible devices before any CUDA context is created.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception("No GPU found or Wrong gpu id, please run without --cuda")

    # set the seed for reproducibility across runs
    setup_seed(GLOBAL_SEED)
    print("=====> set Global Seed: ", GLOBAL_SEED)

    cudnn.enabled = True

    # build the model and initialization (Kaiming init for convs, given
    # eps/momentum for BatchNorm layers)
    model = build_model(args.model, num_classes=args.classes)
    init_weight(model, nn.init.kaiming_normal_, nn.BatchNorm2d, 1e-3, 0.1, mode='fan_in')

    print("=====> computing network parameters and FLOPs")
    total_paramters = netParams(model)
    print("the number of parameters: %d ==> %.2f M" % (total_paramters, (total_paramters / 1e6)))

    # load data and data augmentation; ``datas`` carries dataset statistics
    # (class weights, mean, std) computed by the dataset builder
    datas, trainLoader, valLoader = build_dataset_train(args.dataset, input_size, args.batch_size,
                                                       args.train_type, args.random_scale,
                                                       args.random_mirror, args.num_workers)
    print('=====> Dataset statistics')
    print("data['classWeights']: ", datas['classWeights'])
    print('mean and std: ', datas['mean'], datas['std'])

    # define loss function per dataset, weighted by class frequency
    weight = torch.from_numpy(datas['classWeights'])

    if args.dataset == 'camvid':
        criteria = CrossEntropyLoss2d(weight=weight, ignore_label=ignore_label)
    elif args.dataset == 'cityscapes':
        # NOTE(review): len(args.gpus) counts *characters* of the gpu-id
        # string (e.g. '0,1' -> 3), not the number of GPUs — confirm this
        # min_kept budget for online hard-example mining is intentional.
        min_kept = int(args.batch_size // len(args.gpus) * h * w // 16)
        criteria = ProbOhemCrossEntropy2d(use_weight=True, ignore_label=ignore_label,
                                          thresh=0.7, min_kept=min_kept)
    elif args.dataset == 'paris':
        criteria = CrossEntropyLoss2d(weight=weight, ignore_label=ignore_label)
        # criteria = nn.CrossEntropyLoss(weight=weight)
        # min_kept = int(args.batch_size // len(args.gpus) * h * w // 16)
        # criteria = ProbOhemCrossEntropy2d(ignore_label=ignore_label, thresh=0.7, min_kept=min_kept, use_weight=False)
    elif args.dataset == 'austin':
        # binary segmentation datasets use a binary cross-entropy variant
        criteria = BinCrossEntropyLoss2d(weight=weight)
    elif args.dataset == 'road':
        criteria = BinCrossEntropyLoss2d(weight=weight)
    else:
        raise NotImplementedError(
            "This repository now supports two datasets: cityscapes and camvid, %s is not included" % args.dataset)

    if args.cuda:
        criteria = criteria.cuda()
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = nn.DataParallel(model).cuda()  # multi-card data parallel
        else:
            args.gpu_nums = 1
            print("single GPU for training")
            model = model.cuda()  # 1-card data parallel

    # NOTE(review): args.gpu_nums is only assigned in the args.cuda branch
    # above — confirm the argument parser provides a default for CPU runs,
    # otherwise the line below raises AttributeError.
    # Output directory encodes dataset/model/batch/gpu/train-type/date.
    args.savedir = (args.savedir + args.dataset + '/' + args.model + 'bs'
                    + str(args.batch_size) + 'gpu' + str(args.gpu_nums)
                    + "_" + str(args.train_type) + '/' + str(date) + '/')

    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    start_epoch = 0

    # continue training from a saved checkpoint if --resume points at one
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            # model.load_state_dict(convert_state_dict(checkpoint['model']))
            print("=====> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True

    # Open the text log: append to an existing file, or create a new one
    # and write the header row (param count, seed, args, column names).
    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Parameters: %s Seed: %s\n %s\n" % (str(total_paramters / 1e6), GLOBAL_SEED, args))
        logger.write("\n%s\t\t%s\t\t%s\t\t%s\t%s\t%s" % ('Epoch', ' lr', ' Loss', ' Pa', ' Mpa', ' mIOU'))
        for i in range(args.classes):
            logger.write("\t%s" % ('Class' + str(i)))
    logger.flush()

    # define optimization criteria (Adam for camvid, SGD elsewhere)
    if args.dataset == 'camvid':
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()), args.lr, (0.9, 0.999),
            eps=1e-08, weight_decay=2e-4)
    elif args.dataset == 'cityscapes':
        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()), args.lr,
            momentum=0.9, weight_decay=1e-4)
    elif args.dataset == 'paris':
        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()), args.lr,
            momentum=0.9, weight_decay=1e-4)
    elif args.dataset == 'austin':
        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()), args.lr,
            momentum=0.9, weight_decay=1e-4)
    elif args.dataset == 'road':
        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()), args.lr,
            momentum=0.9, weight_decay=1e-4)

    # Per-epoch bookkeeping for logging and plotting.
    lossTr_list = []
    epoches = []
    mIOU_val_list = []
    max_miou = 0
    miou = 0
    print('***********************************************\n'
          '******* Begining traing *******\n'
          '***********************************************')

    for epoch in range(start_epoch, args.max_epochs):
        # training: one full pass over trainLoader, returns mean loss and lr
        lossTr, lr = train(args, trainLoader, model, criteria, optimizer, epoch)
        lossTr_list.append(lossTr)

        # validation every args.val_epochs epochs (and always on the last)
        if (epoch % args.val_epochs == 0 and args.train_val == 'True') or epoch == (args.max_epochs - 1):
            epoches.append(epoch)
            miou, iou, fmiou, pa, mpa = val(args, valLoader, model)
            mIOU_val_list.append(miou)
            # record train information
            # NOTE(review): header columns are (Pa, Mpa, mIOU) but the
            # values written here are (fmiou, pa, miou) — verify alignment.
            logger.write("\n %d\t\t%.6f\t%.5f\t\t%.4f\t%0.4f\t%0.4f" % (epoch, lr, lossTr, fmiou, pa, miou))
            for i in range(len(iou)):
                logger.write("\t%0.4f" % (iou[i]))
            logger.flush()
            print("Epoch %d\tTrain Loss = %.4f\t mIOU(val) = %.4f\t lr= %.5f\n" % (epoch, lossTr, miou, lr))
        else:
            # record train information (no validation this epoch)
            logger.write("\n%d\t%.6f\t\t%.5f" % (epoch, lr, lossTr))
            logger.flush()
            print("Epoch %d\tTrain Loss = %.4f\t lr= %.6f\n" % (epoch, lossTr, lr))

        # save the model: keep new best-mIoU checkpoints during the final 50
        # epochs, otherwise snapshot every args.save_epochs epochs
        model_file_name = args.savedir + '/model_' + str(epoch) + '.pth'
        state = {"epoch": epoch, "model": model.state_dict()}

        if max_miou < miou and epoch >= args.max_epochs - 50:
            max_miou = miou
            torch.save(state, model_file_name)
        elif epoch % args.save_epochs == 0:
            torch.save(state, model_file_name)

        # draw plots for visualization
        if epoch % args.val_epochs == 0 or epoch == (args.max_epochs - 1):
            # Plot the figures per args.val_epochs epochs
            fig1, ax1 = plt.subplots(figsize=(11, 8))
            ax1.plot(range(start_epoch, epoch + 1), lossTr_list)
            ax1.set_title("Average training loss vs epochs")
            ax1.set_xlabel("Epochs")
            ax1.set_ylabel("Current loss")
            plt.savefig(args.savedir + "loss_vs_epochs.png")
            plt.clf()
            fig2, ax2 = plt.subplots(figsize=(11, 8))
            ax2.plot(epoches, mIOU_val_list, label="Val IoU")
            ax2.set_title("Average IoU vs epochs")
            ax2.set_xlabel("Epochs")
            ax2.set_ylabel("Current IoU")
            plt.legend(loc='lower right')
            plt.savefig(args.savedir + "iou_vs_epochs.png")
            plt.close('all')

    logger.close()