def train():
    """Train a DCGAN and periodically dump generated samples as PNG files.

    The generator maps 100-d noise drawn uniformly from [-1, 1] to a
    1x28x28 image (``tanh`` output). The discriminator maps a 1x28x28
    image to a single sigmoid score. ``d_on_g`` chains the two models
    (generator output ``a8`` feeds discriminator input ``b0``) and is the
    model used to update the generator.

    Training runs for 100 epochs over ``tr_X`` in mini-batches of 128.
    Each batch does one discriminator step followed by one generator step.
    Every 50 batches, the current generated batch is combined into one
    image and saved to ``img_dcgan/<epoch>_<index>.png``.

    NOTE(review): assumes ``pp_data.load_data()`` returns training images
    with 28*28 pixels per sample, scaled to [0, 1]. Confirm against
    ``pp_data``.
    """
    # load data
    batch_size = 128
    tr_X, tr_y, va_X, va_y, te_X, te_y = pp_data.load_data()
    # Floor division: a trailing partial batch is skipped.
    n_batches = tr_X.shape[0] // batch_size

    # normalize data between [-1,1] to match the range of the generator's tanh
    tr_X = (tr_X - 0.5) * 2
    # Use the real sample count instead of a hard-coded 50000, so any
    # training-set size can be reshaped.
    tr_X = tr_X.reshape((tr_X.shape[0], 1, 28, 28))
    print(tr_X.shape)

    # generator: noise -> 128x7x7 feature map; the two 2x upsamplings bring
    # it to 28x28 (assuming each 5x5 conv with border_mode=(2, 2) keeps the
    # spatial size).
    a0 = InputLayer(100)
    a1 = Dense(128 * 7 * 7, act='linear')(a0)
    a1 = BN(axis=0)(a1)
    a1 = Reshape(out_shape=(128, 7, 7))(a1)
    a1 = Convolution2D(64, 5, 5, act='linear', border_mode=(2, 2))(a1)
    a1 = BN(axis=(0, 2, 3))(a1)
    a1 = Activation('leaky_relu')(a1)
    a1 = UpSampling2D(size=(2, 2))(a1)
    a1 = Convolution2D(32, 5, 5, act='linear', border_mode=(2, 2))(a1)
    a1 = BN(axis=(0, 2, 3))(a1)
    a1 = Activation('leaky_relu')(a1)
    a1 = UpSampling2D(size=(2, 2))(a1)
    a8 = Convolution2D(1, 5, 5, act='tanh', border_mode=(2, 2), name='a8')(a1)

    g = Model([a0], [a8])
    g.compile()
    g.summary()

    # discriminator
    b0 = InputLayer((1, 28, 28), name='b0')
    b1 = Convolution2D(64, 5, 5, act='relu', border_mode=(0, 0), name='b1')(b0)
    b1 = MaxPooling2D(pool_size=(2, 2))(b1)
    b1 = Convolution2D(128, 5, 5, act='relu', border_mode=(0, 0))(b1)
    b1 = MaxPooling2D(pool_size=(2, 2))(b1)
    b1 = Flatten()(b1)
    b8 = Dense(1, act='sigmoid')(b1)
    d = Model([b0], [b8])
    d.compile()
    d.summary()

    # discriminator on generator.
    # The discriminator is made non-trainable only while it is added to
    # d_on_g. Presumably this makes d_on_g update generator parameters
    # only, while ``d`` itself stays trainable. Confirm against the Model
    # API; the statement order here matters.
    d_on_g = Model()
    d.set_trainability(False)
    d_on_g.add_models([g, d])
    d.set_trainability(True)
    d_on_g.joint_models('a8', 'b0')
    d_on_g.compile()
    d_on_g.summary()

    # optimizer
    opt_d = Adam(1e-4)
    opt_g = Adam(1e-4)

    # optimization function
    f_train_d = d.get_optimization_func(target_dims=[2],
                                        loss_func='binary_crossentropy',
                                        optimizer=opt_d,
                                        clip=None)
    f_train_g = d_on_g.get_optimization_func(target_dims=[2],
                                             loss_func='binary_crossentropy',
                                             optimizer=opt_g,
                                             clip=None)

    # The labels are identical for every batch, so build them once.
    # Discriminator batch: real images are labelled 1, generated images 0.
    d_labels = np.array([1] * batch_size + [0] * batch_size)
    d_labels = d_labels.reshape((d_labels.shape[0], 1))
    # Generator batch: generated images are labelled 1, to deceive the
    # discriminator.
    g_labels = np.array([1] * batch_size)
    g_labels = g_labels.reshape((g_labels.shape[0], 1))

    img_dir = "img_dcgan"
    if not os.path.exists(img_dir):
        os.makedirs(img_dir)

    for epoch in range(100):
        print(epoch)
        for index in range(n_batches):
            # Concatenate generated and real images to train the discriminator.
            noise = np.random.uniform(-1, 1, (batch_size, 100))
            batch_x = tr_X[index * batch_size:(index + 1) * batch_size]
            batch_gx = g.predict(noise)
            batch_x_all = np.concatenate((batch_x, batch_gx))

            # save out generated img, mapped from [-1, 1] back to [0, 255]
            if index % 50 == 0:
                image = pp_data.combine_images(batch_gx)
                image = image * 127.5 + 127.5
                img_path = os.path.join(img_dir, "%d_%d.png" % (epoch, index))
                Image.fromarray(image.astype(np.uint8)).save(img_path)

            # train discriminator
            d_loss = d.train_on_batch(f_train_d, batch_x_all, d_labels)

            # train generator on fresh noise
            noise = np.random.uniform(-1, 1, (batch_size, 100))
            g_loss = d_on_g.train_on_batch(f_train_g, noise, g_labels)
            # Same text the original `print index, "d_loss:", ...` produced.
            print("%s d_loss: %s \tg_loss: %s" % (index, d_loss, g_loss))
def train():
    """Train the joint classifier/detector model with alternating updates.

    Data:
        Loads a pickled dict of train/test features, masks, labels and name
        lists from ``cfg.scrap_fd``. Applies ``pp_data.wipe_click`` and
        ``pp_data.BalanceData2`` to both splits, then reshapes the labels
        into column vectors.

    Model:
        Inputs are ``[lay_in0, lay_z0]``. There are two outputs:
        - layer ``a8``: a CNN + dense classifier;
        - layer ``b8``: a max-pool + dense detector.

    Training:
        Runs 500 rounds with batch size 100. Each round does:
        1. 80 iterations updating only the classifier layers
           (a11, a12, a2, a4, a6, a8);
        2. 20 iterations updating only the detector layer b8.

    Callbacks:
        Every 20 iterations, the model is saved to
        ``cfg.wbl_dev_md_fd + '/cnn_fft'`` and ``_loss_func`` is evaluated
        on the test split.
    """
    _loss_func = _jdc_loss_func0

    # load data
    # NOTE(review): unpickling can execute arbitrary code; only load this
    # file from a trusted location.
    with open(cfg.scrap_fd + '/denoise_enhance_pool_fft_all0.p', 'rb') as f:
        data = cPickle.load(f)
    tr_X, tr_mask, tr_y, tr_na_list = (data['tr_X'], data['tr_mask'],
                                       data['tr_y'], data['tr_na_list'])
    te_X, te_mask, te_y, te_na_list = (data['te_X'], data['te_mask'],
                                       data['te_y'], data['te_na_list'])

    tr_X = pp_data.wipe_click(tr_X, tr_na_list)
    te_X = pp_data.wipe_click(te_X, te_na_list)

    # balance data
    tr_X, tr_mask, tr_y = pp_data.BalanceData2(tr_X, tr_mask, tr_y)
    te_X, te_mask, te_y = pp_data.BalanceData2(te_X, te_mask, te_y)

    # Same space-separated output as `print a, b, c, d`.
    print("%s %s %s %s" % (tr_X.shape, tr_y.shape, te_X.shape, te_y.shape))
    # Only the chunk and frequency sizes are needed to size the inputs.
    _, n_chunks, n_freq = te_X.shape

    # Labels as column vectors.
    tr_y = tr_y.reshape((len(tr_y), 1))
    te_y = te_y.reshape((len(te_y), 1))

    # jdc model
    # classifier
    lay_z0 = InputLayer((n_chunks,))          # shape:(n_songs, n_chunks) keep the length of songs

    lay_in0 = InputLayer((n_chunks, n_freq), name='in0')   # shape: (n_songs, n_chunk, n_freq)
    lay_a1 = lay_in0

    lay_a1 = Lambda(_reshape_3d_to_4d)(lay_a1)
    lay_a1 = Convolution2D(32, 3, 3, act='relu', init_type='glorot_uniform', border_mode=(1, 1), strides=(1, 1), name='a11')(lay_a1)
    lay_a1 = Dropout(0.2)(lay_a1)
    lay_a1 = MaxPool2D(pool_size=(1, 2))(lay_a1)

    lay_a1 = Convolution2D(64, 3, 3, act='relu', init_type='glorot_uniform', border_mode=(1, 1), strides=(1, 1), name='a12')(lay_a1)
    lay_a1 = Dropout(0.2)(lay_a1)
    lay_a1 = MaxPool2D(pool_size=(1, 2))(lay_a1)
    lay_a1 = Lambda(_reshape_4d_to_3d)(lay_a1)

    lay_a1 = Dense(n_hid, act='relu', name='a2')(lay_a1)       # shape: (n_songs, n_chunk, n_hid)
    lay_a1 = Dropout(0.2)(lay_a1)
    lay_a1 = Dense(n_hid, act='relu', name='a4')(lay_a1)
    lay_a1 = Dropout(0.2)(lay_a1)
    lay_a1 = Dense(n_hid, act='relu', name='a6')(lay_a1)
    lay_a1 = Dropout(0.2)(lay_a1)
    lay_a8 = Dense(n_out, act='sigmoid', init_type='zeros', b_init=0, name='a8')(lay_a1)     # shape: (n_songs, n_chunk, n_out)

    # detector
    lay_b1 = lay_in0     # shape: (n_songs, n_chunk, n_freq)
    lay_b2 = Lambda(_conv2d)(lay_b1)
    # NOTE(review): the _conv2d output above is discarded. The next line
    # reshapes lay_b1, not lay_b2, so the conv is not on the path to b8.
    # Kept as-is to preserve the current model graph; confirm whether
    # `Lambda(_reshape_3d_to_4d)(lay_b2)` was intended.
    lay_b2 = Lambda(_reshape_3d_to_4d)(lay_b1)
    lay_b2 = MaxPool2D(pool_size=(1, 2))(lay_b2)
    lay_b2 = Lambda(_reshape_4d_to_3d)(lay_b2)
    lay_b8 = Dense(n_out, act='hard_sigmoid', init_type='zeros', b_init=-2.3, name='b8')(lay_b2)
    md = Model(in_layers=[lay_in0, lay_z0], out_layers=[lay_a8, lay_b8], any_layers=[])

    # print summary info of model
    md.summary()

    # callbacks: save the model and validate on the test split every 20 iterations
    pp_data.CreateFolder(cfg.wbl_dev_md_fd)
    pp_data.CreateFolder(cfg.wbl_dev_md_fd + '/cnn_fft')
    save_model = SaveModel(dump_fd=cfg.wbl_dev_md_fd + '/cnn_fft', call_freq=20, type='iter')
    validation = Validation(tr_x=None, tr_y=None, va_x=None, va_y=None, te_x=[te_X, te_mask], te_y=te_y, batch_size=100, metrics=[_loss_func], call_freq=20, dump_path=None, type='iter')
    callbacks = [save_model, validation]

    # EM-style alternating training.
    classifier_layers = ['a11', 'a12', 'a2', 'a4', 'a6', 'a8']
    detector_layers = ['b8']

    def _set_trainable(names, params):
        # Give each layer its own copy of the param-name list, as the
        # original per-layer calls did.
        for name in names:
            md.find_layer(name).set_trainable_params(list(params))

    md.set_gt_nodes(tr_y)

    # The trainable sets are switched right before each
    # get_optimization_func call. Presumably the built function captures
    # which params it updates; confirm against the Model API.
    _set_trainable(classifier_layers, ['W', 'b'])
    _set_trainable(detector_layers, [])
    opt_classifier = Adam(1e-3)
    f_classify = md.get_optimization_func(loss_func=_loss_func, optimizer=opt_classifier, clip=None)

    _set_trainable(classifier_layers, [])
    _set_trainable(detector_layers, ['W', 'b'])
    opt_detector = Adam(1e-3)
    f_detector = md.get_optimization_func(loss_func=_loss_func, optimizer=opt_detector, clip=None)

    _x, _y = md.preprocess_data([tr_X, tr_mask], tr_y, shuffle=True)

    for _round in range(500):
        print('-----------------------')
        # Each phase calls reset() on its optimizer before running.
        opt_classifier.reset()
        md.do_optimization_func_iter_wise(f_classify, _x, _y, batch_size=100, n_iters=80, callbacks=callbacks, verbose=1)
        print('-----------------------')
        opt_detector.reset()
        md.do_optimization_func_iter_wise(f_detector, _x, _y, batch_size=100, n_iters=20, callbacks=callbacks, verbose=1)