Example #1
def train_auto(train,
               fun,
               transform,
               testdir,
               outdir,
               num_epochs=30,
               model="1.pkl",
               scale_factor=0.3,
               load=False,
               skip_train=False,
               skip_sep=False):
    """
    Trains a network built with \"fun\" with the data generated with \"train\"
    and then separates the files in \"testdir\",writing the result in \"outdir\"

    Parameters
    ----------
    train : Callable, e.g. LargeDataset object
        The callable which generates training data for the network: inputs, target = train()
    fun : lasagne network object, Theano tensor
        The network to be trained  
    transform : transformFFT object
        The Transform object which was used to compute the features (see compute_features.py)
    testdir : string, optional
        The directory where the files to be separated are located
    outdir : string, optional
        The directory where to write the separated files
    num_epochs : int, optional
        The number the epochs to train for (one epoch is when all examples in the dataset are seen by the network)
    model : string, optional
        The path where to save the trained model (theano tensor containing the network) 
    scale_factor : float, optional
        Scale the magnitude of the files to be separated with this factor
    Yields
    ------
    losser : list
        The losses for each epoch, stored in a list
    """

    logging.info("Building Autoencoder")
    input_var2 = T.tensor4('inputs')
    target_var2 = T.tensor4('targets')
    rand_num = T.tensor4('rand_num')

    eps = 1e-8
    alpha = 0.001
    beta = 0.01
    beta_voc = 0.03

    network2 = fun(input_var=input_var2,
                   batch_size=train.batch_size,
                   time_context=train.time_context,
                   feat_size=train.input_size)

    if load:
        params = load_model(model)
        lasagne.layers.set_all_param_values(network2, params)

    prediction2 = lasagne.layers.get_output(network2, deterministic=True)

    rand_num = np.random.uniform(size=(train.batch_size, 1, train.time_context,
                                       train.input_size))
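    # rand_num, declared above as a symbolic tensor4, is re-bound here to a
    # constant NumPy array; the eps * rand_num terms below add a tiny random
    # positive offset so the mask denominators stay away from zero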

    voc = prediction2[:, 0:1, :, :] + eps * rand_num
    bas = prediction2[:, 1:2, :, :] + eps * rand_num
    dru = prediction2[:, 2:3, :, :] + eps * rand_num
    oth = prediction2[:, 3:4, :, :] + eps * rand_num

    mask1 = voc / (voc + bas + dru + oth)
    mask2 = bas / (voc + bas + dru + oth)
    mask3 = dru / (voc + bas + dru + oth)
    mask4 = oth / (voc + bas + dru + oth)
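    # the four masks above are soft (ratio) masks: each source estimate is
    # divided by the sum of all four, so the masks sum to one in every
    # time-frequency bin and the masked sources sum back to the input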

    vocals = mask1 * input_var2
    bass = mask2 * input_var2
    drums = mask3 * input_var2
    others = mask4 * input_var2

    train_loss_recon_vocals = lasagne.objectives.squared_error(
        vocals, target_var2[:, 0:1, :, :])
    alpha_component = alpha * lasagne.objectives.squared_error(
        vocals, target_var2[:, 1:2, :, :])
    alpha_component += alpha * lasagne.objectives.squared_error(
        vocals, target_var2[:, 2:3, :, :])
    train_loss_recon_neg_voc = beta_voc * lasagne.objectives.squared_error(
        vocals, target_var2[:, 3:4, :, :])

    train_loss_recon_bass = lasagne.objectives.squared_error(
        bass, target_var2[:, 1:2, :, :])
    alpha_component += alpha * lasagne.objectives.squared_error(
        bass, target_var2[:, 0:1, :, :])
    alpha_component += alpha * lasagne.objectives.squared_error(
        bass, target_var2[:, 2:3, :, :])
    train_loss_recon_neg = beta * lasagne.objectives.squared_error(
        bass, target_var2[:, 3:4, :, :])

    train_loss_recon_drums = lasagne.objectives.squared_error(
        drums, target_var2[:, 2:3, :, :])
    alpha_component += alpha * lasagne.objectives.squared_error(
        drums, target_var2[:, 0:1, :, :])
    alpha_component += alpha * lasagne.objectives.squared_error(
        drums, target_var2[:, 1:2, :, :])
    train_loss_recon_neg += beta * lasagne.objectives.squared_error(
        drums, target_var2[:, 3:4, :, :])

    vocals_error = train_loss_recon_vocals.sum()
    drums_error = train_loss_recon_drums.sum()
    bass_error = train_loss_recon_bass.sum()
    negative_error = train_loss_recon_neg.sum()
    negative_error_voc = train_loss_recon_neg_voc.sum()
    alpha_component = alpha_component.sum()

    loss = abs(vocals_error + drums_error + bass_error - negative_error -
               alpha_component - negative_error_voc)
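    # the objective sums the reconstruction errors for vocals, bass and drums
    # and subtracts the dissimilarity terms (alpha: other sources' targets,
    # beta/beta_voc: the 'other' stem); abs() keeps the scalar non-negative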

    params1 = lasagne.layers.get_all_params(network2, trainable=True)

    updates = lasagne.updates.adadelta(loss, params1)

    # val_updates=lasagne.updates.nesterov_momentum(loss1, params1, learning_rate=0.00001, momentum=0.7)

    train_fn = theano.function([input_var2, target_var2],
                               loss,
                               updates=updates,
                               allow_input_downcast=True)

    train_fn1 = theano.function([input_var2, target_var2], [
        vocals_error, bass_error, drums_error, negative_error, alpha_component,
        negative_error_voc
    ],
                                allow_input_downcast=True)

    predict_function2 = theano.function([input_var2],
                                        [vocals, bass, drums, others],
                                        allow_input_downcast=True)

    losser = []

    if not skip_train:

        logging.info("Training...")
        for epoch in range(num_epochs):

            train_err = 0
            train_batches = 0
            vocals_err = 0
            drums_err = 0
            bass_err = 0
            negative_err = 0
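            # NB: the two accumulators below shadow alpha_component and
            # beta_voc from the loss construction above; this is safe only
            # because the Theano functions are already compiled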
            alpha_component = 0
            beta_voc = 0
            start_time = time.time()
            for batch in range(train.iteration_size):
                inputs, target = train()
                jump = inputs.shape[2]
                inputs = np.reshape(
                    inputs,
                    (inputs.shape[0], 1, inputs.shape[1], inputs.shape[2]))
                targets = np.ndarray(shape=(inputs.shape[0], 4,
                                            inputs.shape[2], inputs.shape[3]))
                #import pdb;pdb.set_trace()
                targets[:, 0, :, :] = target[:, :, :jump]
                targets[:, 1, :, :] = target[:, :, jump:jump * 2]
                targets[:, 2, :, :] = target[:, :, jump * 2:jump * 3]
                targets[:, 3, :, :] = target[:, :, jump * 3:jump * 4]
                target = None

                train_err += train_fn(inputs, targets)
                [
                    vocals_erre, bass_erre, drums_erre, negative_erre, alpha,
                    betae_voc
                ] = train_fn1(inputs, targets)
                vocals_err += vocals_erre
                bass_err += bass_erre
                drums_err += drums_erre
                negative_err += negative_erre
                beta_voc += betae_voc
                alpha_component += alpha
                train_batches += 1

            print("Epoch {} of {} took {:.3f}s".format(
                epoch + 1, num_epochs,
                time.time() - start_time))
            print("  training loss:\t\t{:.6f}".format(train_err /
                                                      train_batches))
            losser.append(train_err / train_batches)
            print("  training loss for vocals:\t\t{:.6f}".format(
                vocals_err / train_batches))
            print("  training loss for bass:\t\t{:.6f}".format(bass_err /
                                                               train_batches))
            print("  training loss for drums:\t\t{:.6f}".format(drums_err /
                                                                train_batches))
            print("  Beta component:\t\t{:.6f}".format(negative_err /
                                                       train_batches))
            print("  Beta component for voice:\t\t{:.6f}".format(
                beta_voc / train_batches))
            print("  alpha component:\t\t{:.6f}".format(alpha_component /
                                                        train_batches))
            losser.append(train_err / train_batches)
            save_model(model, network2)

    if not skip_sep:

        logging.info("Separating")
        source = ['vocals', 'bass', 'drums', 'other']
        dev_directory = os.listdir(os.path.join(testdir, "Dev"))
        test_directory = os.listdir(os.path.join(testdir, "Test"))
        dirlist = []
        dirlist.extend(dev_directory)
        dirlist.extend(test_directory)
        for f in sorted(dirlist):
            if not f.startswith('.'):
                if f in dev_directory:
                    song = os.path.join(testdir, "Dev", f, "mixture.wav")
                else:
                    song = os.path.join(testdir, "Test", f, "mixture.wav")
                audioObj, sampleRate, bitrate = util.readAudioScipy(song)

                assert sampleRate == 44100, "Sample rate needs to be 44100"

                audio = (audioObj[:, 0] + audioObj[:, 1]) / 2
                audioObj = None
                mag, ph = transform.compute_file(audio, phase=True)

                mag = scale_factor * mag.astype(np.float32)

                batches, nchunks = util.generate_overlapadd(
                    mag,
                    input_size=mag.shape[-1],
                    time_context=train.time_context,
                    overlap=train.overlap,
                    batch_size=train.batch_size,
                    sampleRate=sampleRate)
                output = []

                batch_no = 1
                for batch in batches:
                    batch_no += 1
                    start_time = time.time()
                    output.append(predict_function2(batch))

                output = np.array(output)
                mm = util.overlapadd_multi(output,
                                           batches,
                                           nchunks,
                                           overlap=train.overlap)
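                # mm stacks the re-assembled magnitude spectrograms, one per
                # source, along its first axis, in the order returned by
                # predict_function2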

                #write audio files
                if f in dev_directory:
                    dirout = os.path.join(outdir, "Dev", f)
                else:
                    dirout = os.path.join(outdir, "Test", f)
                if not os.path.exists(dirout):
                    os.makedirs(dirout)
                for i in range(mm.shape[0]):
                    audio_out = transform.compute_inverse(
                        mm[i, :len(ph)] / scale_factor, ph)
                    if len(audio_out) > len(audio):
                        audio_out = audio_out[:len(audio)]
                    util.writeAudioScipy(
                        os.path.join(dirout, source[i] + '.wav'), audio_out,
                        sampleRate, bitrate)
                    audio_out = None
                audio = None

    return losser
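
# A minimal usage sketch for the function above (not from the original
# source). The constructor arguments for transformFFT and LargeDataset are
# assumptions inferred from the attributes train_auto reads (batch_size,
# time_context, input_size, overlap); build_net stands in for any compatible
# network-building function.
#
#   tt = transformFFT(frameSize=1024, hopSize=512)
#   db = LargeDataset(path_transform_in='features/', batch_size=32,
#                     time_context=30, overlap=25)
#   losses = train_auto(db, build_net, tt, 'dataset/', 'separated/',
#                       num_epochs=30, model='model.pkl')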
Example #2
def train_auto(train,
               fun,
               transform,
               testdir,
               outdir,
               num_epochs=30,
               model="1.pkl",
               scale_factor=0.3,
               load=False,
               skip_train=False,
               skip_sep=False):
    """
    Trains a network built with \"fun\" with the data generated with \"train\"
    and then separates the files in \"testdir\",writing the result in \"outdir\"

    Parameters
    ----------
    train : Callable, e.g. LargeDataset object
        The callable which generates training data for the network: inputs, target = train()
    fun : lasagne network object, Theano tensor
        The network to be trained  
    transform : transformFFT object
        The Transform object which was used to compute the features (see compute_features.py)
    testdir : string, optional
        The directory where the files to be separated are located
    outdir : string, optional
        The directory where to write the separated files
    num_epochs : int, optional
        The number the epochs to train for (one epoch is when all examples in the dataset are seen by the network)
    model : string, optional
        The path where to save the trained model (theano tensor containing the network) 
    scale_factor : float, optional
        Scale the magnitude of the files to be separated with this factor
    Yields
    ------
    losser : list
        The losses for each epoch, stored in a list
    """

    logging.info("Building Autoencoder")
    input_var2 = T.tensor4('inputs')
    target_var2 = T.tensor4('targets')
    rand_num = T.tensor4('rand_num')
    
    eps=1e-8
    alpha=0.9
    beta_acc=0.005
    beta_voc=0.02

    network2 = fun(input_var=input_var2,
                   batch_size=train.batch_size,
                   time_context=train.time_context,
                   feat_size=train.input_size)
    
    if load:
        params=load_model(model)
        lasagne.layers.set_all_param_values(network2,params)

    prediction2 = lasagne.layers.get_output(network2, deterministic=True)

    rand_num = np.random.uniform(size=(train.batch_size,1,train.time_context,train.input_size))

    voc=prediction2[:,0:1,:,:]+eps*rand_num
    acco=prediction2[:,1:2,:,:]+eps*rand_num

    mask1=voc/(voc+acco)
    mask2=acco/(voc+acco)

    vocals=mask1*input_var2[:,0:1,:,:]
    acc=mask2*input_var2[:,0:1,:,:]
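    # the two ratio masks sum to one in every bin, so vocals + acc reproduces
    # the (eps-smoothed) input; the same idea in plain NumPy:
    #   v, a = 3.0, 1.0
    #   m1, m2 = v / (v + a), a / (v + a)   # 0.75 and 0.25; m1 + m2 == 1.0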
    
    train_loss_recon_vocals = lasagne.objectives.squared_error(vocals,target_var2[:,0:1,:,:])
    train_loss_recon_acc = alpha * lasagne.objectives.squared_error(acc,target_var2[:,1:2,:,:])    
    train_loss_recon_neg_voc = beta_voc * lasagne.objectives.squared_error(vocals,target_var2[:,1:2,:,:])
    train_loss_recon_neg_acc = beta_acc * lasagne.objectives.squared_error(acc,target_var2[:,0:1,:,:])

    vocals_error=train_loss_recon_vocals.sum()  
    acc_error=train_loss_recon_acc.sum()  
    negative_error_voc=train_loss_recon_neg_voc.sum()
    negative_error_acc=train_loss_recon_neg_acc.sum()
    
    loss=abs(vocals_error+acc_error-negative_error_voc)
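    # note that negative_error_acc is monitored through train_fn1 below but is
    # not part of the training loss above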

    params1 = lasagne.layers.get_all_params(network2, trainable=True)

    updates = lasagne.updates.adadelta(loss, params1)

    train_fn = theano.function([input_var2, target_var2],
                               loss,
                               updates=updates,
                               allow_input_downcast=True)

    train_fn1 = theano.function(
        [input_var2, target_var2],
        [vocals_error, acc_error, negative_error_voc, negative_error_acc],
        allow_input_downcast=True)

    predict_function2 = theano.function([input_var2], [vocals, acc],
                                        allow_input_downcast=True)
    predict_function3 = theano.function(
        [input_var2],
        [prediction2[:, 0:1, :, :], prediction2[:, 1:2, :, :]],
        allow_input_downcast=True)

    losser = []

    if not skip_train:

        logging.info("Training...")
        for epoch in range(num_epochs):

            train_err = 0
            train_batches = 0
            vocals_err=0
            acc_err=0        
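            # NB: beta_voc and beta_acc are re-used below as per-epoch
            # accumulators, shadowing the loss weights defined above; safe
            # only because the Theano functions are already compiled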
            beta_voc=0
            beta_acc=0
            start_time = time.time()
            for batch in range(train.iteration_size): 
                inputs, target = train()
                
                jump = inputs.shape[2]
                targets=np.ndarray(shape=(inputs.shape[0],2,inputs.shape[1],inputs.shape[2]))
                inputs=np.reshape(inputs,(inputs.shape[0],1,inputs.shape[1],inputs.shape[2]))          

                targets[:,0,:,:]=target[:,:,:jump]
                targets[:,1,:,:]=target[:,:,jump:jump*2]         
                target=None
        
                train_err+=train_fn(inputs,targets)
                [vocals_erre,acc_erre,betae_voc,betae_acc]=train_fn1(inputs,targets)
                vocals_err += vocals_erre
                acc_err += acc_erre           
                beta_voc+= betae_voc
                beta_acc+= betae_acc
                train_batches += 1
            
            logging.info("Epoch {} of {} took {:.3f}s".format(
                epoch + 1, num_epochs, time.time() - start_time))
            logging.info("  training loss:\t\t{:.6f}".format(train_err/train_batches))
            logging.info("  training loss for vocals:\t\t{:.6f}".format(vocals_err/train_batches))
            logging.info("  training loss for acc:\t\t{:.6f}".format(acc_err/train_batches))
            logging.info("  Beta component for voice:\t\t{:.6f}".format(beta_voc/train_batches))
            logging.info("  Beta component for acc:\t\t{:.6f}".format(beta_acc/train_batches))
            losser.append(train_err / train_batches)
            save_model(model,network2)

    if not skip_sep:

        logging.info("Separating")
        for f in os.listdir(testdir):
            if f.endswith(".wav"):
                audioObj, sampleRate, bitrate = util.readAudioScipy(os.path.join(testdir,f))
                
                assert sampleRate == 44100,"Sample rate needs to be 44100"

                audio = audioObj[:,0] + audioObj[:,1]
                audioObj = None
                mag,ph=transform.compute_file(audio,phase=True)
         
                mag=scale_factor*mag.astype(np.float32)

                batches, nchunks = util.generate_overlapadd(
                    mag,
                    input_size=mag.shape[-1],
                    time_context=train.time_context,
                    overlap=train.overlap,
                    batch_size=train.batch_size,
                    sampleRate=sampleRate)
                output=[]

                batch_no=1
                for batch in batches:
                    batch_no+=1
                    start_time=time.time()
                    output.append(predict_function2(batch))

                output=np.array(output)
                bmag,mm=util.overlapadd(output,batches,nchunks,overlap=train.overlap)
                
                audio_out=transform.compute_inverse(bmag[:len(ph)]/scale_factor,ph)
                if len(audio_out)>len(audio):
                    audio_out=audio_out[:len(audio)]
                audio_out=essentia.array(audio_out)
                audio_out2= transform.compute_inverse(mm[:len(ph)]/scale_factor,ph) 
                if len(audio_out2)>len(audio):
                    audio_out2=audio_out2[:len(audio)]  
                audio_out2=essentia.array(audio_out2) 
                #write audio files
                util.writeAudioScipy(os.path.join(outdir,f.replace(".wav","-voice.wav")),audio_out,sampleRate,bitrate)
                util.writeAudioScipy(os.path.join(outdir,f.replace(".wav","-music.wav")),audio_out2,sampleRate,bitrate)
                audio_out=None 
                audio_out2=None   

    return losser  
Example #3
def train_auto(train,
               fun,
               transform,
               testdir,
               outdir,
               testfile_list,
               testdir1,
               outdir1,
               testfile_list1,
               num_epochs=30,
               model="1.pkl",
               scale_factor=0.3,
               load=False,
               skip_train=False,
               skip_sep=False):
    """
    Trains a network built with \"fun\" with the data generated with \"train\"
    and then separates the files in \"testdir\",writing the result in \"outdir\"

    Parameters
    ----------
    train : Callable, e.g. LargeDataset object
        The callable which generates training data for the network: inputs, target = train()
    fun : lasagne network object, Theano tensor
        The network to be trained
    transform : transformFFT object
        The Transform object which was used to compute the features (see compute_features.py)
    testdir : string, optional
        The directory where the files to be separated are located
    outdir : string, optional
        The directory where to write the separated files
    num_epochs : int, optional
        The number the epochs to train for (one epoch is when all examples in the dataset are seen by the network)
    model : string, optional
        The path where to save the trained model (theano tensor containing the network)
    scale_factor : float, optional
        Scale the magnitude of the files to be separated with this factor
    Yields
    ------
    losser : list
        The losses for each epoch, stored in a list
    """

    logging.info("Building Autoencoder")
    input_var2 = T.tensor4('inputs')
    target_var2 = T.tensor4('targets')
    rand_num = T.tensor4('rand_num')

    eps = 1e-18
    alpha = 0.001

    network2 = fun(input_var=input_var2,
                   batch_size=train.batch_size,
                   time_context=train.time_context,
                   feat_size=train.input_size)

    if load:
        params = load_model(model)
        lasagne.layers.set_all_param_values(network2, params)

    prediction2 = lasagne.layers.get_output(network2, deterministic=True)

    rand_num = np.random.uniform(size=(train.batch_size, 1, train.time_context,
                                       train.input_size))

    s1 = prediction2[:, 0:1, :, :]
    s2 = prediction2[:, 1:2, :, :]
    s3 = prediction2[:, 2:3, :, :]
    s4 = prediction2[:, 3:4, :, :]

    mask1 = s1 / (s1 + s2 + s3 + s4 + eps * rand_num)
    mask2 = s2 / (s1 + s2 + s3 + s4 + eps * rand_num)
    mask3 = s3 / (s1 + s2 + s3 + s4 + eps * rand_num)
    mask4 = s4 / (s1 + s2 + s3 + s4 + eps * rand_num)

    source1 = mask1 * input_var2[:, 0:1, :, :]
    source2 = mask2 * input_var2[:, 0:1, :, :]
    source3 = mask3 * input_var2[:, 0:1, :, :]
    source4 = mask4 * input_var2[:, 0:1, :, :]

    train_loss_recon1 = lasagne.objectives.squared_error(
        source1, target_var2[:, 0:1, :, :])
    train_loss_recon2 = lasagne.objectives.squared_error(
        source2, target_var2[:, 1:2, :, :])
    train_loss_recon3 = lasagne.objectives.squared_error(
        source3, target_var2[:, 2:3, :, :])
    train_loss_recon4 = lasagne.objectives.squared_error(
        source4, target_var2[:, 3:4, :, :])

    error1 = train_loss_recon1.sum()
    error2 = train_loss_recon2.sum()
    error3 = train_loss_recon3.sum()
    error4 = train_loss_recon4.sum()

    loss = abs(error1 + error2 + error3 + error4)

    params1 = lasagne.layers.get_all_params(network2, trainable=True)

    updates = lasagne.updates.adadelta(loss, params1)

    train_fn = theano.function([input_var2, target_var2],
                               loss,
                               updates=updates,
                               allow_input_downcast=True)

    train_fn1 = theano.function([input_var2, target_var2],
                                [error1, error2, error3, error4],
                                allow_input_downcast=True)

    predict_function2 = theano.function([input_var2],
                                        [source1, source2, source3, source4],
                                        allow_input_downcast=True)

    losser = []

    if not skip_train:

        logging.info("Training...")
        for epoch in range(num_epochs):

            train_err = 0
            train_batches = 0
            err1 = 0
            err2 = 0
            err3 = 0
            err4 = 0
            start_time = time.time()
            for batch in range(train.iteration_size):
                inputs, target = train()

                jump = inputs.shape[2]
                targets = np.ndarray(shape=(inputs.shape[0], 4,
                                            inputs.shape[1], inputs.shape[2]))
                inputs = np.reshape(
                    inputs,
                    (inputs.shape[0], 1, inputs.shape[1], inputs.shape[2]))

                targets[:, 0, :, :] = target[:, :, :jump]
                targets[:, 1, :, :] = target[:, :, jump:jump * 2]
                targets[:, 2, :, :] = target[:, :, jump * 2:jump * 3]
                targets[:, 3, :, :] = target[:, :, jump * 3:jump * 4]
                target = None
                #gc.collect()

                train_err += train_fn(inputs, targets)
                [e1, e2, e3, e4] = train_fn1(inputs, targets)
                err1 += e1
                err2 += e2
                err3 += e3
                err4 += e4
                train_batches += 1

            logging.info("Epoch {} of {} took {:.3f}s".format(
                epoch + 1, num_epochs,
                time.time() - start_time))
            logging.info("  training loss:\t\t{:.6f}".format(train_err /
                                                             train_batches))
            logging.info("  training loss for bassoon:\t\t{:.6f}".format(
                err1 / train_batches))
            logging.info("  training loss for clarinet:\t\t{:.6f}".format(
                err2 / train_batches))
            logging.info("  training loss for saxophone:\t\t{:.6f}".format(
                err3 / train_batches))
            logging.info("  training loss for violin:\t\t{:.6f}".format(
                err4 / train_batches))
            losser.append(train_err / train_batches)
            save_model(model, network2)

    if not skip_sep:

        logging.info("Separating")
        sources = ['bassoon', 'clarinet', 'saxphone', 'violin']
        sources_midi = ['bassoon', 'clarinet', 'saxophone', 'violin']
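        # 'saxphone' is kept as-is in sources: it presumably matches the
        # spelling in the audio filenames on disk, while sources_midi uses the
        # standard spelling for the style-varied files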

        for f in testfile_list:
            for i in range(len(sources)):
                filename = os.path.join(testdir, f,
                                        f + '-' + sources[i] + '.wav')
                audioObj, sampleRate, bitrate = util.readAudioScipy(filename)

                assert sampleRate == 44100, "Sample rate needs to be 44100"

                nframes = int(np.ceil(
                    len(audioObj) / np.double(tt.hopSize))) + 2
                if i == 0:
                    audio = np.zeros(audioObj.shape[0])
                    #melody = np.zeros((len(sources),1,nframes))
                audio = audio + audioObj
                audioObj = None

            mag, ph = transform.compute_file(audio, phase=True)
            mag = scale_factor * mag.astype(np.float32)

            batches, nchunks = util.generate_overlapadd(
                mag,
                input_size=mag.shape[-1],
                time_context=train.time_context,
                overlap=train.overlap,
                batch_size=train.batch_size,
                sampleRate=44100)
            output = []
            #output1=[]

            batch_no = 1
            for batch in batches:
                batch_no += 1
                start_time = time.time()
                output.append(predict_function2(batch))

            output = np.array(output)
            mm = util.overlapadd_multi(output,
                                       batches,
                                       nchunks,
                                       overlap=train.overlap)
            for i in range(len(sources)):
                audio_out = transform.compute_inverse(
                    mm[i, :len(ph)] / scale_factor, ph)
                if len(audio_out) > len(audio):
                    audio_out = audio_out[:len(audio)]
                util.writeAudioScipy(
                    os.path.join(outdir, f + '-' + sources[i] + '.wav'),
                    audio_out, sampleRate, bitrate)
                audio_out = None

        style = ['fast', 'slow', 'original']
        if not os.path.exists(outdir1):
            os.makedirs(outdir1)
        for s in style:
            for f in testfile_list1:
                for i in range(len(sources)):
                    filename = os.path.join(
                        testdir1, f,
                        f + '_' + s + '_' + sources_midi[i] + '.wav')
                    audioObj, sampleRate, bitrate = util.readAudioScipy(
                        filename)

                    assert sampleRate == 44100, "Sample rate needs to be 44100"

                    nframes = int(
                        np.ceil(len(audioObj) / np.double(tt.hopSize))) + 2

                    if i == 0:
                        audio = np.zeros(audioObj.shape[0])
                        #melody = np.zeros((len(sources),1,nframes))
                    audio = audio + audioObj
                    audioObj = None

                mag, ph = transform.compute_file(audio, phase=True)
                mag = scale_factor * mag.astype(np.float32)

                batches, nchunks = util.generate_overlapadd(
                    mag,
                    input_size=mag.shape[-1],
                    time_context=train.time_context,
                    overlap=train.overlap,
                    batch_size=train.batch_size,
                    sampleRate=44100)
                output = []

                batch_no = 1
                for batch in batches:
                    batch_no += 1
                    start_time = time.time()
                    output.append(predict_function2(batch))

                output = np.array(output)
                mm = util.overlapadd_multi(output,
                                           batches,
                                           nchunks,
                                           overlap=train.overlap)
                for i in range(len(sources)):
                    audio_out = transform.compute_inverse(
                        mm[i, :len(ph)] / scale_factor, ph)
                    if len(audio_out) > len(audio):
                        audio_out = audio_out[:len(audio)]
                    filename = os.path.join(
                        outdir1, f + '_' + s + '_' + sources_midi[i] + '.wav')
                    util.writeAudioScipy(filename, audio_out, sampleRate,
                                         bitrate)
                    audio_out = None

    return losser
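
# A hedged usage sketch for the variant above (placeholder names, not from
# the original source): testfile_list names pieces laid out as
# testdir/<piece>/<piece>-<source>.wav, and the second directory/list pair
# points at the fast/slow/original renderings used in the second loop.
#
#   losses = train_auto(db, build_net, tt,
#                       'quartets/', 'out/', ['piece1', 'piece2'],
#                       'quartets_var/', 'out_var/', ['piece1'],
#                       num_epochs=30, model='model_quartet.pkl')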
Example #4
def train_auto(fun,
               transform,
               testdir,
               outdir,
               testfile_list,
               testdir1,
               outdir1,
               testfile_list1,
               num_epochs=30,
               model="1.pkl",
               scale_factor=0.3,
               load=False,
               skip_train=False,
               skip_sep=False,
               path_transform_in=None,
               nsamples=40,
               batch_size=32,
               batch_memory=50,
               time_context=30,
               overlap=25,
               nprocs=4,
               mult_factor_in=0.3,
               mult_factor_out=0.3,
               timbre_model_path=None):
    """
    Trains a network built with \"fun\" with the data generated with \"train\"
    and then separates the files in \"testdir\",writing the result in \"outdir\"

    Parameters
    ----------
    fun : lasagne network object, Theano tensor
        The network to be trained
    transform : transformFFT object
        The Transform object which was used to compute the features (see compute_features.py)
    testdir : string, optional
        The directory where the files to be separated are located
    outdir : string, optional
        The directory where to write the separated files
    num_epochs : int, optional
        The number the epochs to train for (one epoch is when all examples in the dataset are seen by the network)
    model : string, optional
        The path where to save the trained model (theano tensor containing the network)
    scale_factor : float, optional
        Scale the magnitude of the files to be separated with this factor
    Yields
    ------
    losser : list
        The losses for each epoch, stored in a list
    """

    logging.info("Building Autoencoder")
    input_var2 = T.tensor4('inputs')
    target_var2 = T.tensor4('targets')
    rand_num = T.tensor4('rand_num')

    #parameters for the score-informed separation
    nharmonics=20
    interval=50 #cents
    tuning_freq=440 #Hz

    eps=1e-18
    alpha=0.001

    input_size = int(float(transform.frameSize) / 2 + 1)

    network2 = fun(input_var=input_var2,
                   batch_size=batch_size,
                   time_context=time_context,
                   feat_size=input_size,
                   nchannels=4)

    if load:
        params=load_model(model)
        lasagne.layers.set_all_param_values(network2,params)

    prediction2 = lasagne.layers.get_output(network2, deterministic=True)

    rand_num = np.random.uniform(size=(batch_size,1,time_context,input_size))

    s1=prediction2[:,0:1,:,:]
    s2=prediction2[:,1:2,:,:]
    s3=prediction2[:,2:3,:,:]
    s4=prediction2[:,3:4,:,:]

    mask1=s1/(s1+s2+s3+s4+eps*rand_num)
    mask2=s2/(s1+s2+s3+s4+eps*rand_num)
    mask3=s3/(s1+s2+s3+s4+eps*rand_num)
    mask4=s4/(s1+s2+s3+s4+eps*rand_num)

    input_var = input_var2[:,0:1,:,:] + input_var2[:,1:2,:,:] + input_var2[:,2:3,:,:] + input_var2[:,3:4,:,:]

    source1=mask1*input_var[:,0:1,:,:]
    source2=mask2*input_var[:,0:1,:,:]
    source3=mask3*input_var[:,0:1,:,:]
    source4=mask4*input_var[:,0:1,:,:]

    train_loss_recon1 = lasagne.objectives.squared_error(source1,target_var2[:,0:1,:,:])
    train_loss_recon2 = lasagne.objectives.squared_error(source2,target_var2[:,1:2,:,:])
    train_loss_recon3 = lasagne.objectives.squared_error(source3,target_var2[:,2:3,:,:])
    train_loss_recon4 = lasagne.objectives.squared_error(source4,target_var2[:,3:4,:,:])

    error1=train_loss_recon1.sum()
    error2=train_loss_recon2.sum()
    error3=train_loss_recon3.sum()
    error4=train_loss_recon4.sum()

    loss=abs(error1+error2+error3+error4)

    params1 = lasagne.layers.get_all_params(network2, trainable=True)

    updates = lasagne.updates.adadelta(loss, params1)

    train_fn = theano.function([input_var2,target_var2], loss, updates=updates,allow_input_downcast=True)

    train_fn1 = theano.function([input_var2,target_var2], [error1,error2,error3,error4], allow_input_downcast=True)

    predict_function2=theano.function([input_var2],[source1,source2,source3,source4],allow_input_downcast=True)

    losser=[]
    min_loss = 1e14

    training_steps = 0

    if not skip_train:

        logging.info("Training...")
        for epoch in range(num_epochs):
            train = LargeDatasetMask2(
                path_transform_in=path_transform_in, nsources=4,
                nsamples=nsamples, batch_size=batch_size,
                batch_memory=batch_memory, time_context=time_context,
                overlap=overlap, nprocs=nprocs,
                mult_factor_in=scale_factor, mult_factor_out=scale_factor,
                sampleRate=transform.sampleRate, pitch_code='e',
                nharmonics=20, pitch_norm=127.,
                tensortype=theano.config.floatX,
                timbre_model_path=timbre_model_path)
            train_err = 0
            train_batches = 0
            err1=0
            err2=0
            err3=0
            err4=0
            start_time = time.time()
            for batch in range(train.iteration_size):

                inputs, target, masks = train()
                jump = inputs.shape[2]

                mask=np.empty(shape=(inputs.shape[0],4,inputs.shape[1],inputs.shape[2]),dtype=theano.config.floatX)
                mask[:,0,:,:]=masks[:,:,:jump] * inputs
                mask[:,1,:,:]=masks[:,:,jump:jump*2] * inputs
                mask[:,2,:,:]=masks[:,:,jump*2:jump*3] * inputs
                mask[:,3,:,:]=masks[:,:,jump*3:jump*4] * inputs
                masks=None
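                # the score-informed masks gate the mixture magnitude before it
                # enters the network: each of the four input channels keeps only
                # the time-frequency regions allowed by that source's score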

                targets=np.empty(shape=(inputs.shape[0],4,inputs.shape[1],inputs.shape[2]),dtype=theano.config.floatX)
                targets[:,0,:,:]=target[:,:,:jump]
                targets[:,1,:,:]=target[:,:,jump:jump*2]
                targets[:,2,:,:]=target[:,:,jump*2:jump*3]
                targets[:,3,:,:]=target[:,:,jump*3:jump*4]
                target=None

                inputs=None

                train_err+=train_fn(mask,targets)
                [e1,e2,e3,e4]=train_fn1(mask,targets)
                err1 += e1
                err2 += e2
                err3 += e3
                err4 += e4
                train_batches += 1

            logging.info("Epoch {} of {} took {:.3f}s".format(
                epoch + 1, num_epochs, time.time() - start_time))
            logging.info("  training loss:\t\t{:.6f}".format(train_err/train_batches))
            logging.info("  training loss for bassoon:\t\t{:.6f}".format(err1/train_batches))
            logging.info("  training loss for clarinet:\t\t{:.6f}".format(err2/train_batches))
            logging.info("  training loss for saxophone:\t\t{:.6f}".format(err3/train_batches))
            logging.info("  training loss for violin:\t\t{:.6f}".format(err4/train_batches))
            losser.append(train_err / train_batches)
            #save_model(model,network2)
            # if (train_err/train_batches) < min_loss:
            #     min_loss = train_err/train_batches
            save_model(model,network2)

        # training_steps = training_steps + 1
        # num_epochs = int(np.ceil(float(num_epochs)/5.))

        # if losser[-1] > min_loss:
        #     params=load_model(model)
        #     lasagne.layers.set_all_param_values(network2,params,learning_rate=0.0001)

        # updates = lasagne.updates.adam(loss, params1)
        # train_fn = theano.function([input_var2,target_var2], loss, updates=updates,allow_input_downcast=True)


    if not skip_sep:

        logging.info("Separating")
        sources = ['bassoon','clarinet','saxphone','violin']
        sources_midi = ['bassoon','clarinet','saxophone','violin']

        train = LargeDatasetMask2(
            path_transform_in=path_transform_in, nsources=4,
            batch_size=batch_size, batch_memory=batch_memory,
            time_context=time_context, overlap=overlap, nprocs=nprocs,
            mult_factor_in=scale_factor, mult_factor_out=scale_factor,
            sampleRate=transform.sampleRate, pitch_code='e', nharmonics=20,
            pitch_norm=127., tensortype=theano.config.floatX,
            timbre_model_path=timbre_model_path)

        for f in testfile_list:
            nelem_g=1
            for i in range(len(sources)):
                ng = util.getMidiNum(sources_midi[i]+'_b',os.path.join(testdir,f),0,40.0)
                nelem_g = np.maximum(ng,nelem_g)
            melody = np.zeros((len(sources),int(nelem_g),2*nharmonics+3))
            for i in range(len(sources)):
                filename=os.path.join(testdir,f,f+'-'+sources[i]+'.wav')
                audioObj, sampleRate, bitrate = util.readAudioScipy(filename)

                assert sampleRate == 44100,"Sample rate needs to be 44100"

                nframes = int(np.ceil(len(audioObj) / np.double(tt.hopSize))) + 2
                if i==0:
                    audio = np.zeros(audioObj.shape[0])
                audio = audio + audioObj
                audioObj=None

                tmp = util.expandMidi(sources_midi[i]+'_b',
                                      os.path.join(testdir,f), 0, 40.0,
                                      interval, tuning_freq, nharmonics,
                                      sampleRate, tt.hopSize, tt.frameSize,
                                      0.2, 0.2, nframes, 0.5)
                melody[i,:tmp.shape[0],:] = tmp
                tmp = None

            mag,ph=transform.compute_file(audio,phase=True)
            mag=scale_factor*mag.astype(np.float32)

            jump = mag.shape[-1]

            masks_temp = train.filterSpec(mag,melody,0,nframes)
            masks = np.ones((train.ninst,mag.shape[0],mag.shape[1]))
            masks[0,:,:]=masks_temp[:,:jump] * mag
            masks[1,:,:]=masks_temp[:,jump:jump*2] * mag
            masks[2,:,:]=masks_temp[:,jump*2:jump*3] * mag
            masks[3,:,:]=masks_temp[:,jump*3:jump*4] * mag
            mag = None
            masks_temp = None

            batches, nchunks = util.generate_overlapadd(
                masks,
                input_size=masks.shape[-1],
                time_context=train.time_context,
                overlap=train.overlap,
                batch_size=train.batch_size,
                sampleRate=44100)
            masks = None

            batch_no=1
            output=[]
            for batch in batches:
                batch_no+=1
                #start_time=time.time()
                output.append(predict_function2(batch))

            output=np.array(output)
            mm=util.overlapadd_multi(output,batches,nchunks,overlap=train.overlap)
            for i in range(len(sources)):
                audio_out=transform.compute_inverse(mm[i,:len(ph)]/scale_factor,ph)
                if len(audio_out)>len(audio):
                    audio_out=audio_out[:len(audio)]
                util.writeAudioScipy(os.path.join(outdir,f+'-'+sources[i]+'.wav'),audio_out,sampleRate,bitrate)
                audio_out=None

        # style = ['fast','slow','original']
        # style_midi = ['_fast20','_slow20','_original']
        # if not os.path.exists(outdir1):
        #     os.makedirs(outdir1)
        # for s in range(len(style)):
        #     for f in testfile_list1:
        #         nelem_g=1
        #         for i in range(len(sources)):
        #             ng = util.getMidiNum(sources_midi[i]+'_g'+style_midi[s],os.path.join(testdir1,f),0,40.0)
        #             nelem_g = np.maximum(ng,nelem_g)
        #         melody = np.zeros((len(sources),int(nelem_g),2*nharmonics+3))
        #         for i in range(len(sources)):
        #             filename=os.path.join(testdir1,f,f+'_'+style[s]+'_'+sources_midi[i]+'.wav')

        #             audioObj, sampleRate, bitrate = util.readAudioScipy(filename)

        #             assert sampleRate == 44100,"Sample rate needs to be 44100"

        #             nframes = int(np.ceil(len(audioObj) / np.double(tt.hopSize))) + 2

        #             if i==0:
        #                 audio = np.zeros(audioObj.shape[0])

        #             audio = audio + audioObj
        #             audioObj=None

        #             tmp = util.expandMidi(sources_midi[i]+'_g'+style_midi[s],os.path.join(testdir1,f),0,40.0,interval,tuning_freq,nharmonics,sampleRate,tt.hopSize,tt.frameSize,0.2,0.2,nframes)
        #             melody[i,:tmp.shape[0],:] = tmp
        #             tmp = None

        #         mag,ph=transform.compute_file(audio,phase=True)
        #         mag=scale_factor*mag.astype(np.float32)

        #         jump = mag.shape[-1]

        #         masks_temp = train.filterSpec(mag,melody,0,nframes)
        #         masks = np.ones((train.ninst,mag.shape[0],mag.shape[1]))
        #         masks[0,:,:]=masks_temp[:,:jump] * mag
        #         masks[1,:,:]=masks_temp[:,jump:jump*2] * mag
        #         masks[2,:,:]=masks_temp[:,jump*2:jump*3] * mag
        #         masks[3,:,:]=masks_temp[:,jump*3:jump*4] * mag
        #         mag = None
        #         masks_temp = None

        #         batches,nchunks = util.generate_overlapadd(masks,input_size=masks.shape[-1],time_context=train.time_context,overlap=train.overlap,batch_size=train.batch_size,sampleRate=44100)
        #         masks = None

        #         batch_no=1
        #         output=[]
        #         for batch in batches:
        #             batch_no+=1
        #             #start_time=time.time()
        #             output.append(predict_function2(batch))

        #         output=np.array(output)
        #         mm=util.overlapadd_multi(output,batches,nchunks,overlap=train.overlap)
        #         for i in range(len(sources)):
        #             audio_out=transform.compute_inverse(mm[i,:len(ph)]/scale_factor,ph)
        #             if len(audio_out)>len(audio):
        #                 audio_out=audio_out[:len(audio)]
        #             filename=os.path.join(outdir1,f+'_'+style[s]+'_'+sources_midi[i]+'.wav')
        #             util.writeAudioScipy(filename,audio_out,sampleRate,bitrate)
        #             audio_out=None

    return losser
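
# A hedged usage sketch for the score-informed variant above (placeholder
# paths, not from the original source; LargeDatasetMask2 is built inside the
# function from path_transform_in and the batching keywords):
#
#   losses = train_auto(build_net, tt,
#                       'quartets/', 'out/', ['piece1'],
#                       'quartets_var/', 'out_var/', ['piece1'],
#                       num_epochs=30, model='model_si.pkl',
#                       path_transform_in='features/', batch_size=32,
#                       batch_memory=50, time_context=30, overlap=25,
#                       timbre_model_path='timbre_model.pkl')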