Exemplo n.º 1
0
def test(args):
    """Evaluate a mask estimator on the 'dt' split of the Chime dataset.

    Builds a ``DataLoader`` over ``Chime_Dataset('dt', args)``, instantiates
    the estimator selected by ``args.model_type``, makes sure the matching
    model directory exists under ``args.data_dir``, and runs
    ``ModelTest(...).test()`` with a binary cross-entropy criterion.

    Args:
        args: Namespace-like object. This function reads ``batch_size``,
            ``num_workers``, ``model_type`` and ``data_dir``; the whole
            object is also forwarded to ``Chime_Dataset`` and ``ModelTest``.

    Raises:
        ValueError: If ``args.model_type`` is neither ``'BLSTM'`` nor ``'FW'``.
    """
    from test_utils import ModelTest

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare Data
    test_dataset = Chime_Dataset('dt', args)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             # NOTE(review): shuffling is kept as in the
                             # original; evaluation normally does not need
                             # it -- confirm whether ModelTest relies on it.
                             shuffle=True,
                             pin_memory=True,
                             # Pass the callable itself rather than a lambda
                             # wrapper: lambdas cannot be pickled, which
                             # breaks worker processes under the "spawn"
                             # start method when num_workers > 0.
                             collate_fn=Chime_Collate,
                             num_workers=args.num_workers)

    # Prepare model
    if args.model_type == 'BLSTM':
        model = BLSTMMaskEstimator()
        model_dir_name = 'BLSTM_model'
    elif args.model_type == 'FW':
        model = SimpleFWMaskEstimator()
        model_dir_name = 'FW_model'
    else:
        raise ValueError('Unknown model type. Possible are "BLSTM" and "FW"')
    # Ensure the per-model directory exists under the data directory.
    mkdir_p(os.path.join(args.data_dir, model_dir_name))

    criterion = torch.nn.BCELoss()

    tester = ModelTest(model, test_loader, criterion, args, device)
    tester.test()
Exemplo n.º 2
0
def train(args):
    """Train a mask estimator on the 'tr' split of the Chime dataset.

    Builds a shuffled ``DataLoader`` over ``Chime_Dataset('tr', args)``,
    instantiates the estimator selected by ``args.model_type``, makes sure
    the matching model directory exists under ``args.data_dir``, and runs
    ``ModelTrainer(...).train(args.num_epochs)`` using binary cross-entropy
    loss and an Adam optimizer (weight decay ``1e-4``).

    Args:
        args: Namespace-like object. This function reads ``batch_size``,
            ``num_workers``, ``model_type``, ``data_dir``, ``learning_rate``
            and ``num_epochs``; the whole object is also forwarded to
            ``Chime_Dataset`` and ``ModelTrainer``.

    Raises:
        ValueError: If ``args.model_type`` is neither ``'BLSTM'`` nor ``'FW'``.
    """
    from train_utils import ModelTrainer

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare Data
    train_dataset = Chime_Dataset('tr', args)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              pin_memory=True,
                              # Pass the callable itself rather than a lambda
                              # wrapper: lambdas cannot be pickled, which
                              # breaks worker processes under the "spawn"
                              # start method when num_workers > 0.
                              collate_fn=Chime_Collate,
                              num_workers=args.num_workers)

    # Prepare model
    if args.model_type == 'BLSTM':
        model = BLSTMMaskEstimator()
        model_dir_name = 'BLSTM_model'
    elif args.model_type == 'FW':
        model = SimpleFWMaskEstimator()
        model_dir_name = 'FW_model'
    else:
        raise ValueError('Unknown model type. Possible are "BLSTM" and "FW"')
    # Ensure the per-model directory exists under the data directory.
    mkdir_p(os.path.join(args.data_dir, model_dir_name))

    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=1e-4)

    trainer = ModelTrainer(model, train_loader, criterion, optimizer, args,
                           device)
    trainer.train(args.num_epochs)
Exemplo n.º 3
0
# Command-line interface: four positional arguments plus an optional GPU id.
parser = argparse.ArgumentParser(description='NN GEV beamforming')
parser.add_argument('model', help='Trained model file')
parser.add_argument('model_type', help='Type of model (BLSTM or FW)')
parser.add_argument('--gpu', '-g', type=int, default=-1,
                    help='GPU ID (negative value indicates CPU)')
parser.add_argument('data_directory', help='data experiment directory')
parser.add_argument('exNum', help='Experiment order')
args = parser.parse_args()

# Prepare model: reject unsupported types first, then build the estimator.
if args.model_type not in ('BLSTM', 'FW'):
    raise ValueError('Unknown model type. Possible are "BLSTM" and "FW"')
if args.model_type == 'BLSTM':
    model = BLSTMMaskEstimator()
else:
    model = SimpleFWMaskEstimator()

# Load the file given as the `model` argument into the estimator.
serializers.load_hdf5(args.model, model)
print("data type of 'model'", type(model))

# Choose the array module. For a non-negative GPU id, also select that
# device and move the model onto it before switching to cupy.
if args.gpu < 0:
    xp = np
else:
    cuda.get_device(args.gpu).use()
    model.to_gpu()
    xp = cuda.cupy

# def single_noise():
#     audio_data = get_audio_nochime('new_dataset/2m/2m_pub_new', ch_range=range(1, 9), fs=49000)
#     # audio_data = get_audio_nochime('new_dataset/new_audio/AUDIO_RECORDING', ch_range=range(1, 9), fs=49000)