optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) #select the optimizer exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1) # create the train_dataset_loader and val_dataset_loader. cloud_data = CloudDataset(img_dir='data/images/', labels_dir='data/GTmaps/') trainer = Trainer('inference', optimizer, exp_lr_scheduler, net, cfig, './log') trainer.load_weights(trainer.find_last()) #trainer.load_weights('log/renset20190102T1348/model_renset_0046.pt') since = time.time() for x in range(0, 801, 5): images = cloud_data[x]['image'] gt_map = cloud_data[x]['gt_map'] mask = trainer.detect(images) mask = np.round(mask * 255) # images=cv2.cvtColor(images,cv2.COLOR_BGR2GRAY) # cv2.imwrite('result/{}_image.png'.format(x),images) # cv2.imwrite('result/{}gt_map.png'.format(x),gt_map) #cv2.imwrite('result/{}sigmoid.png'.format(x),mask) #fig.set_size_inches(600/100.0,600/100.0)#输出width*height像素 print(mask.shape) fig = plt.figure() fig.set_size_inches(600 / 100.0, 300 / 100.0) #输出width*height像素 plt.subplot(121) plt.xticks([]) #去掉横坐标值 plt.yticks([]) #去掉纵坐标值