コード例 #1
0
    def __init__(self, cfg):
        """Build the training setup (device, STFT, data, model, optimizer) from cfg."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.dtype = torch.float32
        self.eps = 1e-4  # small numerical-stability constant

        self.stft_module = STFTModule(cfg['stft_params'], self.device)

        # Hyper-parameters read from the config dict.
        self.train_data_num = cfg['train_data_num']
        self.valid_data_num = cfg['valid_data_num']
        self.sample_len = cfg['sample_len']
        self.epoch_num = cfg['epoch_num']
        self.train_batch_size = cfg['train_batch_size']
        self.valid_batch_size = cfg['valid_batch_size']

        # DSD100 train/validation splits and their loaders.
        self.train_dataset = DSD100Dataset(
            data_num=self.train_data_num,
            sample_len=self.sample_len,
            folder_type='train')
        self.valid_dataset = DSD100Dataset(
            data_num=self.valid_data_num,
            sample_len=self.sample_len,
            folder_type='validation')
        self.train_data_loader = FastDataLoader(self.train_dataset,
                                                batch_size=self.train_batch_size,
                                                shuffle=True)
        self.valid_data_loader = FastDataLoader(self.valid_dataset,
                                                batch_size=self.valid_batch_size,
                                                shuffle=True)

        # Model, loss and optimizer.
        self.model = UNet().to(self.device)
        self.criterion = MSE()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        self.save_path = 'results/model/dsd_unet_config_1/'
コード例 #2
0
    def __init__(self, cfg):
        """Prepare evaluation: load a trained UNet and build the test loader."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.dtype = torch.float32
        self.eps = 1e-4  # small numerical-stability constant
        self.eval_path = cfg['eval_path']

        # Build the model, switch to inference mode, then restore weights.
        self.model = UNet().to(self.device)
        self.model.eval()
        self.model.load_state_dict(torch.load(self.eval_path,
                                              map_location=self.device))

        self.stft_module = STFTModule(cfg['stft_params'], self.device)

        # Test-set configuration; shuffling is disabled for evaluation.
        self.test_data_num = cfg['test_data_num']
        self.test_batch_size = cfg['test_batch_size']
        self.sample_len = cfg['sample_len']
        self.test_dataset = DSD100Dataset(
            data_num=self.test_data_num,
            sample_len=self.sample_len,
            folder_type='test',
            shuffle=False)
        self.test_data_loader = FastDataLoader(
            self.test_dataset,
            batch_size=self.test_batch_size,
            shuffle=False)

        # Accumulators for BSS-eval style metrics over the test set.
        self.sdr_list = np.array([])
        self.sar_list = np.array([])
        self.sir_list = np.array([])
コード例 #3
0
def get_model(model_name, num_classes, weight=None):
    """Instantiate a segmentation model by name.

    Args:
        model_name: one of 'fcn8', 'fcn16', 'fcn32', 'unet', 'duc',
            'duc_hdc', 'gcn', 'pspnet', 'segnet'.
        num_classes: number of output classes for the model head.
        weight: optional path to a state-dict checkpoint to load.

    Returns:
        The constructed model, or None when model_name is not recognized.
    """
    print('====> load {}-classes segmentation model: {}'.format(
        num_classes, model_name))
    # Lazy factories: only the selected model is ever constructed.
    factories = {
        'fcn8': lambda: FCN8s(num_classes=num_classes),
        'fcn16': lambda: FCN16VGG(num_classes=num_classes, pretrained=False),
        'fcn32': lambda: FCN32VGG(num_classes=num_classes, pretrained=False),
        'unet': lambda: UNet(num_classes=num_classes),
        'duc': lambda: ResNetDUC(num_classes=num_classes),
        'duc_hdc': lambda: ResNetDUCHDC(num_classes=num_classes),
        'gcn': lambda: GCN(num_classes=num_classes, input_size=512),
        'pspnet': lambda: PSPNet(num_classes=num_classes),
        'segnet': lambda: SegNet(num_classes=num_classes),
    }
    factory = factories.get(model_name)
    model = factory() if factory is not None else None
    if weight is not None:
        model.load_state_dict(torch.load(weight))
    return model
コード例 #4
0
ファイル: u_net_setting.py プロジェクト: jinming0912/RGBxD
    def __init__(self, setting_path=None):
        """Experiment settings for UNet RGB-D semantic segmentation.

        Configures GPU visibility, TF execution mode, dataset, model,
        LR schedule and optimizer.

        Args:
            setting_path: optional path to a saved setting; when given,
                setup stops early (remaining logic elided in this snippet).
        """
        self.name = 'UNet'

        # Restrict the process to one GPU; must be set before TF touches CUDA.
        self.gpu_id = '0'
        os.environ['CUDA_VISIBLE_DEVICES'] = self.gpu_id

        # Eager-execution switch (graph mode when disabled).
        self.enable_eager = True
        if not self.enable_eager:
            tf.compat.v1.disable_eager_execution()

        # Output root for summaries/checkpoints (keeps the original spelling).
        self.save_folder = os.path.join(os.environ['HOME'], "Summarys")

        self.batch_size = 3
        self.num_epochs = 2000
        # presumably "run validation every N epochs" — TODO confirm
        self.validation_freq = 2

        # NYU-Depth-V2 dataset with scale / flip / Gaussian-noise augmentation.
        self.dataset = NyuV2Fcn(
            batch_size=self.batch_size,
            img_size=[640, 480],
            base_dir=os.path.join(os.environ['HOME'], "Datasets/Downloads/NYU-Depth-V2/nyud"),
            min_scale_factor=0.5,
            max_scale_factor=2,
            scale_factor_step_size=0.25,
            with_aug_flip=True,
            with_aug_scale_crop=True,
            with_aug_gaus_noise=True)
        # Alternative dataset kept for reference:
        # self.dataset = SunRgbd(
        #     batch_size=self.batch_size,
        #     base_dir=os.path.join(os.environ['HOME'], "Datasets/sun_rgbd"),
        #     min_scale_factor=0.5,
        #     max_scale_factor=2,
        #     scale_factor_step_size=0.25,  # 0.25
        #     with_aug_flip=True,
        #     with_aug_scale_crop=True,
        #     with_aug_gaus_noise=True)

        # Class 0 gets zero weight in both loss and metrics
        # (presumably a background / ignore class — TODO confirm).
        self.loss_weights = [0.0] + [1.0] * (self.dataset.num_class - 1)
        self.metrics_weights = [0.0] + [1.0] * (self.dataset.num_class - 1)

        self.model = UNet(num_class=self.dataset.num_class,
                          fusion_mode='bgr_hha_gw',
                          # 'bgr_hha', 'bgr', 'hha', 'bgr_hha_gw', bgr_hha_xyz
                          name='UNet',
                          backbone='mobile_net_v2',
                          backbone_initializer='imagenet')

        # Polynomial decay spanning the full training run.
        decay_steps = self.num_epochs * self.dataset.get_trainval_size()
        self.lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(initial_learning_rate=0.00015,
                                                                         power=0.9,
                                                                         decay_steps=decay_steps,
                                                                         end_learning_rate=0.00001,
                                                                         name='poly_lr')
        # NOTE(review): the optimizer is built with a fixed LR of 0.001, so
        # self.lr_schedule above is currently unused — confirm this is intended.
        self.optimizer = tf.keras.optimizers.SGD(learning_rate=0.001,  # 0.001, self.lr_schedule
                                                 # decay=0.0005,
                                                 momentum=0.9)
        # self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

        if setting_path is not None:
            # ... (restore-from-file logic elided in this snippet)
            return
コード例 #5
0
 def __init__(self, cfg):
     """Training setup for Slakh source separation (device, data, UNet, Adam)."""
     self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     self.dtype = torch.float32
     self.eps = 1e-4  # small numerical-stability constant

     # STFT front-end and config-driven hyper-parameters.
     self.stft_module = STFTModule(cfg['stft_params'])
     self.inst_num = cfg['inst_num']
     self.train_data_num = cfg['train_data_num']
     self.valid_data_num = cfg['valid_data_num']
     self.sample_len = cfg['sample_len']
     self.epoch_num = cfg['epoch_num']
     self.train_batch_size = cfg['train_batch_size']
     self.valid_batch_size = cfg['valid_batch_size']

     # Slakh datasets and loaders for the train / validation splits.
     self.train_dataset = SlakhDataset(inst_num=self.inst_num,
                                       data_num=self.train_data_num,
                                       sample_len=self.sample_len,
                                       folder_type='train')
     self.valid_dataset = SlakhDataset(inst_num=self.inst_num,
                                       data_num=self.valid_data_num,
                                       sample_len=self.sample_len,
                                       folder_type='validation')
     self.train_data_loader = torch.utils.data.DataLoader(
         self.train_dataset, batch_size=self.train_batch_size, shuffle=True)
     self.valid_data_loader = torch.utils.data.DataLoader(
         self.valid_dataset, batch_size=self.valid_batch_size, shuffle=True)

     # Model, loss and optimizer.
     self.model = UNet().to(self.device)
     self.criterion = MSE()
     self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
     self.save_path = 'results/model/inst_num_{0}/'.format(self.inst_num)
コード例 #6
0
                                           batch_size=opt.batch_size,
                                           shuffle=True,
                                           num_workers=opt.num_workers)


def weights_init(m):
    """DCGAN-style initialization, intended for use with Module.apply().

    Conv* layers: weights ~ N(0, 0.02), bias = 0.
    BatchNorm* layers: weights ~ N(1, 0.02), bias = 0.
    Other modules are left untouched.
    """
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
        # Conv layers built with bias=False have m.bias is None;
        # guard to avoid AttributeError.
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif class_name.find('BatchNorm') != -1:
        # BatchNorm with affine=False has weight/bias set to None.
        if m.weight is not None:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)


net = UNet(opt.input_nc, opt.output_nc)

# Either resume from a checkpoint or apply random DCGAN-style init.
# NOTE(review): the guard checks opt.net but the load uses opt.netG —
# looks inconsistent; confirm which option actually holds the path.
if opt.net != '':
    net.load_state_dict(torch.load(opt.netG))
else:
    net.apply(weights_init)
if opt.cuda:
    net.cuda()
if opt.num_GPU > 1:
    net = nn.DataParallel(net)

###########   LOSS & OPTIMIZER   ##########
# criterion = nn.BCELoss()
# 255 in the target maps is treated as an ignore label.
criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index=255)
optimizer = torch.optim.Adam(net.parameters(),
                             lr=opt.lr,
コード例 #7
0
class UNetRunner():
    """Trainer for a multi-instrument UNet separator on the Slakh dataset.

    The network predicts a time-frequency mask from the mixture's
    log-magnitude spectrogram; masked mixture amplitudes are compared
    against the per-instrument source amplitudes.
    """

    def __init__(self, cfg):
        # Device, dtype and the numerical floor used before log10.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.dtype = torch.float32
        self.eps = 1e-4

        self.stft_module = STFTModule(cfg['stft_params'])
        # Hyper-parameters from the config dict.
        self.inst_num = cfg['inst_num']
        self.train_data_num = cfg['train_data_num']
        self.valid_data_num = cfg['valid_data_num']
        self.sample_len = cfg['sample_len']
        self.epoch_num = cfg['epoch_num']
        self.train_batch_size = cfg['train_batch_size']
        self.valid_batch_size = cfg['valid_batch_size']

        self.train_dataset = SlakhDataset(inst_num=self.inst_num, data_num=self.train_data_num, sample_len=self.sample_len, folder_type='train')
        self.valid_dataset = SlakhDataset(inst_num=self.inst_num, data_num=self.valid_data_num, sample_len=self.sample_len, folder_type='validation')

        self.train_data_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=True)
        self.valid_data_loader = torch.utils.data.DataLoader(self.valid_dataset, batch_size=self.valid_batch_size, shuffle=True)
        self.model = UNet().to(self.device)
        self.criterion = MSE()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        self.save_path = 'results/model/inst_num_{0}/'.format(self.inst_num)

    def _preprocess(self, mixture, true_sources):
        """STFT the mixture and targets; build network input and loss targets.

        Returns (mix_mag_spec, true_sources_amp_spec, true_res_amp_spec,
        mix_phase, mix_amp_spec).
        """
        mix_spec = self.stft_module.stft(mixture, pad=True)
        # NOTE(review): named "phase", but this indexes position 1 of axis 2;
        # whether that is actually the phase depends on STFTModule's output
        # layout — confirm.
        mix_phase = mix_spec[:,:,1]
        mix_amp_spec = taF.complex_norm(mix_spec)
        # Drop the first frequency bin (presumably DC) — TODO confirm.
        mix_amp_spec = mix_amp_spec[:,1:,:]
        mix_mag_spec = torch.log10(mix_amp_spec + self.eps)  # eps avoids log10(0)
        # NOTE(review): the frequency axis was already cropped two lines up,
        # so this drops one more bin — possibly an unintended double slice.
        mix_mag_spec = mix_mag_spec[:,1:,:]

        true_sources_spec = self.stft_module.stft_3D(true_sources, pad=True)
        true_sources_amp_spec = taF.complex_norm(true_sources_spec)
        true_sources_amp_spec = true_sources_amp_spec[:,:,1:,:]

        # Residual per instrument: mixture minus that instrument's source.
        true_res = mixture.unsqueeze(1).repeat(1, self.inst_num, 1) - true_sources
        true_res_spec = self.stft_module.stft_3D(true_res, pad=True)
        true_res_amp_spec = taF.complex_norm(true_res_spec)
        return mix_mag_spec, true_sources_amp_spec, true_res_amp_spec, mix_phase, mix_amp_spec

    def _postporcess(self, est_sources):
        # Placeholder (name keeps the original "postporcess" spelling).
        pass

    def _run(self, model, criterion, data_loader, batch_size, mode=None):
        """One pass over data_loader; backprop only when mode == 'train'.

        Returns (mean loss, last est_sources, last est_mask, last mix_amp_spec).
        """
        running_loss = 0
        for i, (mixture, sources, _) in enumerate(data_loader):
            mixture = mixture.to(self.dtype).to(self.device)
            sources = sources.to(self.dtype).to(self.device)
            mix_mag_spec, true_sources_amp_spec, true_res_amp_spec, _, mix_amp_spec = self._preprocess(mixture, sources)

            model.zero_grad()
            # Mask predicted from log-magnitude, applied to mixture amplitudes.
            est_mask = model(mix_mag_spec.unsqueeze(1))
            est_sources = mix_amp_spec.unsqueeze(1) * est_mask

            if mode == 'train' or mode == 'validation':
                # 10x scaling of the criterion (choice not documented here).
                loss = 10 * criterion(est_sources, true_sources_amp_spec, self.inst_num)
                running_loss += loss.data
                if mode == 'train':
                    loss.backward()
                    self.optimizer.step()

        # NOTE(review): 'i' is unbound if the loader yields no batches.
        return (running_loss / (i+1)), est_sources, est_mask, mix_amp_spec

    def train(self):
        """Train/validate for epoch_num epochs; checkpoint every 10 epochs."""
        train_loss = np.array([])
        valid_loss = np.array([])
        print("start train")
        for epoch in range(self.epoch_num):
            # train
            print('epoch{0}'.format(epoch))
            start = time.time()
            self.model.train()
            tmp_train_loss, _, _, _ = self._run(self.model, self.criterion, self.train_data_loader, self.train_batch_size, mode='train')
            train_loss = np.append(train_loss, tmp_train_loss.cpu().clone().numpy())
            # validation
            self.model.eval()
            with torch.no_grad():
                tmp_valid_loss, est_source, est_mask, mix_amp_spec = self._run(self.model, self.criterion, self.valid_data_loader, self.valid_batch_size, mode='validation')
                valid_loss = np.append(valid_loss, tmp_valid_loss.cpu().clone().numpy())

            if (epoch + 1) % 10 == 0:
                torch.save(self.model.state_dict(), self.save_path + 'u_net{0}.ckpt'.format(epoch + 1))

            end = time.time()
            # (string keeps the original "excute" spelling)
            print('----excute time: {0}'.format(end - start))
            # Visualize the last validation batch of the epoch.
            show_TF_domein_result(valid_loss, mix_amp_spec[0,:,:], est_mask[0,0,:,:], est_source[0,0,:,:])
コード例 #8
0
class UNetRunner():
    """Trainer for a UNet vocal-separation model on DSD100.

    The network predicts a time-frequency mask from the mixture's
    log-magnitude spectrogram; the masked mixture amplitudes are
    compared against the vocals' amplitude spectrogram via MSE.
    """

    def __init__(self, cfg):
        # Device, dtype and the numerical floor used before log10.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.dtype = torch.float32
        self.eps = 1e-4

        self.stft_module = STFTModule(cfg['stft_params'], self.device)
        # Hyper-parameters from the config dict.
        self.train_data_num = cfg['train_data_num']
        self.valid_data_num = cfg['valid_data_num']
        self.sample_len = cfg['sample_len']
        self.epoch_num = cfg['epoch_num']
        self.train_batch_size = cfg['train_batch_size']
        self.valid_batch_size = cfg['valid_batch_size']

        # DSD100 splits and loaders.
        self.train_dataset = DSD100Dataset(data_num=self.train_data_num,
                                           sample_len=self.sample_len,
                                           folder_type='train')
        self.valid_dataset = DSD100Dataset(data_num=self.valid_data_num,
                                           sample_len=self.sample_len,
                                           folder_type='validation')

        self.train_data_loader = FastDataLoader(
            self.train_dataset, batch_size=self.train_batch_size, shuffle=True)
        self.valid_data_loader = FastDataLoader(
            self.valid_dataset, batch_size=self.valid_batch_size, shuffle=True)
        self.model = UNet().to(self.device)
        self.criterion = MSE()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        self.save_path = 'results/model/dsd_unet_config_1/'

    def _preprocess(self, mixture, true):
        """STFT both signals; return (mix log-mag, true amp, mix amp) specs.

        Runs under no_grad: the STFT front-end is not trained.
        """
        with torch.no_grad():
            mix_spec = self.stft_module.stft(mixture, pad=True)

            mix_amp_spec = taF.complex_norm(mix_spec)
            # Drop the first frequency bin (presumably DC) — TODO confirm.
            mix_amp_spec = mix_amp_spec[:, 1:, :]
            mix_mag_spec = torch.log10(mix_amp_spec + self.eps)  # eps avoids log10(0)

            true_spec = self.stft_module.stft(true, pad=True)
            true_amp_spec = taF.complex_norm(true_spec)
            true_amp_spec = true_amp_spec[:, 1:, :]

        return mix_mag_spec, true_amp_spec, mix_amp_spec

    def _postporcess(self, est_sources):
        # Placeholder (name keeps the original "postporcess" spelling).
        pass

    def _run(self, mode=None, data_loader=None):
        """One pass over data_loader; backprop only when mode == 'train'.

        Returns (mean loss, last est_source, last est_mask,
        last mix_amp_spec, last true_amp_spec).
        """
        running_loss = 0
        for i, (mixture, _, _, _, vocals) in enumerate(data_loader):
            mixture = mixture.to(self.dtype).to(self.device)
            true = vocals.to(self.dtype).to(self.device)
            mix_mag_spec, true_amp_spec, mix_amp_spec = self._preprocess(
                mixture, true)
            self.model.zero_grad()
            # Predict a mask from log-magnitude, apply it to the amplitudes.
            est_mask = self.model(mix_mag_spec.unsqueeze(1))
            est_source = mix_amp_spec.unsqueeze(1) * est_mask

            if mode == 'train' or mode == 'validation':
                # 10x scaling on the MSE loss (choice not documented here).
                loss = 10 * self.criterion(est_source, true_amp_spec)
                running_loss += loss.data
                if mode == 'train':
                    loss.backward()
                    self.optimizer.step()

        # NOTE(review): 'i' is unbound if the loader yields no batches.
        return (running_loss /
                (i + 1)), est_source, est_mask, mix_amp_spec, true_amp_spec

    def train(self):
        """Train for epoch_num epochs; plot and checkpoint every 10 epochs."""
        train_loss = np.array([])
        print("start train")
        for epoch in range(self.epoch_num):
            # train
            print('epoch{0}'.format(epoch))
            start = time.time()
            self.model.train()
            tmp_train_loss, est_source, est_mask, mix_amp_spec, true_amp_spec = self._run(
                mode='train', data_loader=self.train_data_loader)
            train_loss = np.append(train_loss,
                                   tmp_train_loss.cpu().clone().numpy())

            if (epoch + 1) % 10 == 0:
                plot_time = time.time()
                # Visualize the last batch of the epoch.
                show_TF_domein_result(train_loss, mix_amp_spec[0, :, :],
                                      est_mask[0, 0, :, :],
                                      est_source[0, 0, :, :],
                                      true_amp_spec[0, :, :])
                print('plot_time:', time.time() - plot_time)
                torch.save(self.model.state_dict(),
                           self.save_path + 'u_net{0}.ckpt'.format(epoch + 1))

            end = time.time()
            # (string keeps the original "excute" spelling)
            print('----excute time: {0}'.format(end - start))
コード例 #9
0
class Transformer(object):
    """Resize a PIL image and convert it to a tensor scaled to [-1, 1]."""

    def __init__(self, size, interpolation=Image.BILINEAR):
        # Target size and resampling filter used by PIL's resize().
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()

    def __call__(self, img_):
        # Resize, convert to a [0, 1] tensor, then map onto [-1, 1] in place.
        resized = img_.resize(self.size, self.interpolation)
        tensor = self.toTensor(resized)
        return tensor.sub_(0.5).div_(0.5)


#model = Segnet(3,3)
#model_path = './checkpoint/Segnet/model/netG_final.pth'
# Load the trained UNet generator weights on CPU.
# NOTE(review): model.eval() is never called, so any BatchNorm/Dropout
# layers run in training mode here — confirm this is intended.
model = UNet(3, 3)
model_path = './checkpoint/Unet/model/netG_final.pth'
model.load_state_dict(torch.load(model_path, map_location='cpu'))

# Run a single local image through the network.
test_image_path = 'C:/Users/1/Desktop/11.png'
test_image = Image.open(test_image_path).convert('RGB')
print('Operating...')
transformer = Transformer((256, 256))
img = transformer(test_image)
img = img.unsqueeze(0)  # add batch dimension
img = Variable(img)
label_image = model(img)
label_image = label_image.squeeze(0)
show = ToPILImage()
# Map the [-1, 1] output back to [0, 1]; ToPILImage rescales 0-1 to 0-255
# automatically, so 0.5 becomes 127.  (Translated from the original comment.)
a = show((label_image + 1) / 2)
print(a.getpixel((100, 100)))
コード例 #10
0
class UNetTester():
    """Evaluates a trained UNet vocal separator on DSD100 with BSS-eval metrics."""

    def __init__(self, cfg):
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.dtype = torch.float32
        self.eps = 1e-4  # floor added before log10 in _preprocess
        self.eval_path = cfg['eval_path']

        # Build the model, switch to inference mode, and restore weights.
        self.model = UNet().to(self.device)
        self.model.eval()
        self.model.load_state_dict(
            torch.load(self.eval_path, map_location=self.device))

        self.stft_module = STFTModule(
            cfg['stft_params'],
            self.device,
        )
        self.test_data_num = cfg['test_data_num']
        self.test_batch_size = cfg['test_batch_size']
        self.sample_len = cfg['sample_len']
        # Shuffling disabled everywhere for deterministic evaluation order.
        self.test_dataset = DSD100Dataset(data_num=self.test_data_num,
                                          sample_len=self.sample_len,
                                          folder_type='test',
                                          shuffle=False)
        self.test_data_loader = FastDataLoader(self.test_dataset,
                                               batch_size=self.test_batch_size,
                                               shuffle=False)

        # Per-track metric accumulators.
        self.sdr_list = np.array([])
        self.sar_list = np.array([])
        self.sir_list = np.array([])

    def _preprocess(self, mixture, true):
        """Return (log-magnitude spectrogram, complex spectrogram) of the mixture.

        Note: `true` is accepted but unused here.
        """
        mix_spec = self.stft_module.stft(mixture, pad=True)
        mix_amp_spec = taF.complex_norm(mix_spec)
        # Drop the first frequency bin (presumably DC) — TODO confirm.
        mix_amp_spec = mix_amp_spec[:, 1:, :]
        mix_mag_spec = torch.log10(mix_amp_spec + self.eps)  # eps avoids log10(0)
        return mix_mag_spec, mix_spec

    def _postprocess(self, x):
        """Re-insert the frequency bin dropped in _preprocess as zeros.

        Squeezes the channel dim and zero-pads one bin at the bottom of the
        frequency axis so the mask lines up with the full complex spectrogram.
        """
        x = x.squeeze(1)
        batch_size, f_size, t_size = x.shape
        pad_x = torch.zeros((batch_size, f_size + 1, t_size),
                            dtype=self.dtype,
                            device=self.device)
        pad_x[:, 1:, :] = x[:, :, :]
        return pad_x

    def test(self, mode='test'):
        """Separate every test track and print mean SDR / SIR / SAR."""
        with torch.no_grad():
            for i, (mixture, _, _, _,
                    vocals) in enumerate(self.test_data_loader):
                start = time.time()
                mixture = mixture.squeeze(0).to(self.dtype).to(self.device)
                true = vocals.squeeze(0).to(self.dtype).to(self.device)

                # Mask the complex spectrogram, then invert to a waveform.
                mix_mag_spec, mix_spec = self._preprocess(mixture, true)
                est_mask = self.model(mix_mag_spec.unsqueeze(1))
                est_mask = self._postprocess(est_mask)
                est_source = mix_spec * est_mask[..., None]
                est_wave = self.stft_module.istft(est_source)

                # Flatten to 1-D signals; accompaniment = mixture - vocals.
                est_wave = est_wave.flatten()
                mixture = mixture.flatten()
                true = true.flatten()
                true_accompany = mixture - true
                est_accompany = mixture - est_wave
                sdr, sir, sar = mss_evals(est_wave, est_accompany, true,
                                          true_accompany)
                self.sdr_list = np.append(self.sdr_list, sdr)
                self.sar_list = np.append(self.sar_list, sar)
                self.sir_list = np.append(self.sir_list, sir)
                print('test time:', time.time() - start)

            print('sdr mean:', np.mean(self.sdr_list))
            print('sir mean:', np.mean(self.sir_list))
            print('sar mean:', np.mean(self.sar_list))
コード例 #11
0
# Loader over the training dataset.
# NOTE(review): "train_datatset_" keeps the (misspelled) name of a variable
# defined elsewhere in this file — do not "fix" it here without renaming there.
train_loader = torch.utils.data.DataLoader(dataset=train_datatset_, batch_size=opt.batch_size, shuffle=True,
                                           num_workers=opt.num_workers)


###########   MODEL   ###########
def weights_init(m):
    """DCGAN-style initialization, intended for use with Module.apply().

    Conv* layers: weights ~ N(0, 0.02), bias = 0.
    BatchNorm* layers: weights ~ N(1, 0.02), bias = 0.
    Other modules are left untouched.
    """
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
        # Conv layers built with bias=False have m.bias is None;
        # guard to avoid AttributeError.
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif class_name.find('BatchNorm') != -1:
        # BatchNorm with affine=False has weight/bias set to None.
        if m.weight is not None:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)


net = UNet(opt.input_nc, opt.output_nc)
writer = SummaryWriter(log_dir='train/UNet_run')

# Either resume from a checkpoint or apply random DCGAN-style init.
# NOTE(review): the guard checks opt.net but the load uses opt.netG —
# looks inconsistent; confirm which option actually holds the path.
if opt.net != '':
    net.load_state_dict(torch.load(opt.netG))
else:
    net.apply(weights_init)
if opt.cuda:
    net.cuda()
if opt.num_GPU > 1:
    # Fixed typo: was nn.DaraParallel, which raises AttributeError at runtime.
    net = nn.DataParallel(net)

# print(net)

###########   LOSS & OPTIMIZER   ##########
criterion = nn.BCELoss()