def main(args=None):
    """Generate random periodic time series and train an autoencoder model.
    
    args: dict
        Dictionary of values to override default values in `keras_util.parse_model_args`;
        can also be passed via command line. See `parse_model_args` for full list of
        possible arguments.
    """
    args = ku.parse_model_args(args)

    train = np.arange(args.N_train); test = np.arange(args.N_test) + args.N_train
    X, Y, X_raw = sample_data.periodic(args.N_train + args.N_test, args.n_min, args.n_max,
                                       even=args.even, noise_sigma=args.sigma,
                                       kind=args.data_type)

    if args.even:
        X = X[:, :, 1:2]
        X_raw = X_raw[:, :, 1:2]
    else:
        X[:, :, 0] = ku.times_to_lags(X_raw[:, :, 0])
        X[np.isnan(X)] = -1.
        X_raw[np.isnan(X_raw)] = -1.

    model_type_dict = {'gru': GRU, 'lstm': LSTM, 'vanilla': SimpleRNN}

    main_input = Input(shape=(X.shape[1], X.shape[-1]), name='main_input')
    if args.even:
        model_input = main_input
        aux_input = None
    else:
        aux_input = Input(shape=(X.shape[1], X.shape[-1] - 1), name='aux_input')
        model_input = [main_input, aux_input]

    encode = encoder(main_input, layer=model_type_dict[args.model_type],
                     output_size=args.embedding, **vars(args))
    decode = decoder(encode, num_layers=args.decode_layers if args.decode_layers
                                                           else args.num_layers,
                     layer=model_type_dict[args.decode_type if args.decode_type
                                           else args.model_type],
                     n_step=X.shape[1], aux_input=aux_input,
                     **{k: v for k, v in vars(args).items() if k != 'num_layers'})
    model = Model(model_input, decode)

    run = ku.get_run_id(**vars(args))
 
    if args.even:
        history = ku.train_and_log(X[train], X_raw[train], run, model, **vars(args))
    else:
        sample_weight = (X[train, :, -1] != -1)
        history = ku.train_and_log({'main_input': X[train], 'aux_input': X[train, :, 0:1]},
                                   X_raw[train, :, 1:2], run, model,
                                   sample_weight=sample_weight, **vars(args))
    return X, Y, X_raw, model, args
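For reference, a minimal usage sketch follows; the keys and values below are illustrative overrides only (the authoritative argument list and defaults live in `keras_util.parse_model_args`):

X, Y, X_raw, model, args = main({'N_train': 1000, 'N_test': 100,  # train/test sizes (illustrative)
                                 'n_min': 100, 'n_max': 200,      # points per series
                                 'sigma': 0.1,                    # noise level
                                 'even': False,                   # irregularly sampled series
                                 'model_type': 'gru',             # 'gru', 'lstm', or 'vanilla'
                                 'embedding': 8})                 # size of the latent encoding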
Example #2
def main(args=None):
    args = ku.parse_model_args(args)

    N = args.N_train + args.N_test
    train = np.arange(args.N_train)
    test = np.arange(args.N_test) + args.N_train
    X, Y, X_raw = sample_data.periodic(N,
                                       args.n_min,
                                       args.n_max,
                                       even=args.even,
                                       noise_sigma=args.sigma,
                                       kind=args.data_type)

    if args.even:
        X = X[:, :, 1:2]
    else:
        X[:, :, 0] = ku.times_to_lags(X_raw[:, :, 0])
        X[np.isnan(X)] = -1.
        X_raw[np.isnan(X_raw)] = -1.

    Y = sample_data.phase_to_sin_cos(Y)
    scaler = StandardScaler(copy=False, with_mean=True, with_std=True)
    scaler.fit_transform(Y)
    if args.loss_weights:  # so far, only used to zero out some columns
        Y *= args.loss_weights

    model_type_dict = {'gru': GRU, 'lstm': LSTM, 'vanilla': SimpleRNN}

    model_input = Input(shape=(X.shape[1], X.shape[-1]), name='main_input')
    encode = encoder(model_input,
                     layer=model_type_dict[args.model_type],
                     output_size=Y.shape[-1],
                     **vars(args))
    model = Model(model_input, encode)

    run = ku.get_run_id(**vars(args))

    history = ku.train_and_log(X[train], Y[train], run, model, **vars(args))
    return X, Y, X_raw, scaler, model, args
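The `test` indices above are created but never used; a short evaluation sketch might look like the following (predictions come back in the scaled space, so `scaler.inverse_transform` undoes the StandardScaler applied to `Y`):

X, Y, X_raw, scaler, model, args = main()
test = np.arange(args.N_test) + args.N_train
Y_pred = model.predict(X[test])
mse = np.mean((Y_pred - Y[test]) ** 2)               # error in the scaled space
Y_pred_unscaled = scaler.inverse_transform(Y_pred)   # back to sin/cos-of-phase units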
Example #3
def main(args=None):
    args = parse_model_args(args)
    K.set_floatx('float64')

    run = get_run_id(**vars(args))
    log_dir = os.path.join(os.getcwd(), 'keras_logs', args.sim_type, run)
    weights_path = os.path.join(log_dir, 'weights.h5')  # trained model weights

    print("log_dir", log_dir)
    print("Weight matrix read...")

    # Rebuild the trained model by re-running survey_autoencoder with the same
    # args; element 2 of the returned tuple is the Keras model.
    model = survey_autoencoder(vars(args))[2]

    # Load the fitted GMM parameters saved by the training run
    gmm_para = np.load(os.path.join(log_dir, 'gmm_parameters.npz'))
    gmm_mu = gmm_para['gmm_mu']  # size: embedding size x number of classes

    # Sub-model mapping the network inputs to the decoded light curves
    # (the 'time_dist' output layer).
    decode_model = Model(inputs=model.input,
                         outputs=model.get_layer('time_dist').output)

    # NOTE: the decoder is driven directly with the GMM means and no aux
    # input is supplied, so this assumes evenly sampled inputs.
    gmm_mu = np.float64(gmm_mu)
    decoding_train = decode_model.predict(gmm_mu)
    print(decoding_train.shape)

    # Plot the decoded prototype light curve for each GMM component
    phase = np.linspace(0, 1, len(decoding_train[0]))
    for i in range(len(decoding_train)):
        plt.plot(phase, decoding_train[i])
    plt.show()
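Note that indexing the survey_autoencoder return value by position, as done above, is brittle; the autoencoder examples below return the tuple (X, X_raw, model, means, scales, wrong_units, args), so an explicit unpacking reads more clearly:

_, _, model, means, scales, wrong_units, _ = survey_autoencoder(vars(args))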
def main(args=None):
    """Train an autoencoder model from `LightCurve` objects saved in
    `args.survey_files`.
    
    args: dict
        Dictionary of values to override default values in `keras_util.parse_model_args`;
        can also be passed via command line. See `parse_model_args` for full list of
        possible arguments.
    """
    args = ku.parse_model_args(args)

    np.random.seed(0)

    K.set_floatx('float64')

    run = ku.get_run_id(**vars(args))

    if not args.survey_files:
        raise ValueError("No survey files given")

    lc_lists = [joblib.load(f) for f in args.survey_files]
    n_reps = [max(len(y) for y in lc_lists) // len(x) for x in lc_lists]
    combined = sum([x * i for x, i in zip(lc_lists, n_reps)], [])

    # Preparation for classification probability
    classprob_pkl = joblib.load("./data/asassn/class_probs.pkl")
    class_probability = dict(classprob_pkl)

    # Combine subclasses into eight superclasses:
    # CEPH, DSCT, ECL, RRAB, RRCD, M, ROT, SR
    for lc in combined:
        if lc.label in ('EW', 'EA', 'EB'):
            lc.label = 'ECL'
        if lc.label in ('CWA', 'CWB', 'DCEP', 'DCEPS', 'RVA'):
            lc.label = 'CEPH'
        if lc.label in ('DSCT', 'HADS'):
            lc.label = 'DSCT'
        if lc.label in ('RRD', 'RRC'):
            lc.label = 'RRCD'
    top_classes = ['SR', 'RRAB', 'RRCD', 'M', 'ROT', 'ECL', 'CEPH', 'DSCT']

    print ("Number of raw LCs:", len(combined))

    if args.lomb_score:
        combined = [lc for lc in combined if lc.best_score >= args.lomb_score]
    if args.ss_resid:
        combined = [lc for lc in combined if lc.ss_resid <= args.ss_resid]
    if args.class_prob:
        combined = [lc for lc in combined if float(class_probability[lc.name.split("/")[-1][2:-4]]) >= args.class_prob]
        #combined = [lc for lc in combined if lc.class_prob >= args.class_prob]

    # Select only superclasses for training
    combined = [lc for lc in combined if lc.label in top_classes]

    split = [el for lc in combined for el in lc.split(args.n_min, args.n_max)]
    if args.period_fold:
        for lc in split:
            lc.period_fold()

    X_list = [np.c_[lc.times, lc.measurements, lc.errors] for lc in split]
    classnames, indices = np.unique([lc.label for lc in split], return_inverse=True)
    y = classnames[indices]
    periods = np.array([np.log10(lc.p) for lc in split])
    periods = periods.reshape(len(split), 1)

    X_raw = pad_sequences(X_list, value=np.nan, dtype='float64', padding='post')

    model_type_dict = {'gru': GRU, 'lstm': LSTM, 'vanilla': SimpleRNN}
    X, means, scales, wrong_units, X_err = preprocess(X_raw, args.m_max)

    y = y[~wrong_units]
    periods = periods[~wrong_units]

    # Prepare the indices for training and validation
    train, valid = list(StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED).split(X, y))[0]

    X_valid = X[valid]
    y_valid = y[valid]
    means_valid = means[valid]
    scales_valid = scales[valid]
    periods_valid = periods[valid]
    energy_dummy_valid = np.zeros((X_valid.shape[0], 1))
    X_err_valid = X_err[valid]
    sample_weight_valid = 1. / X_err_valid

    X = X[train]
    y = y[train]
    means = means[train]
    scales = scales[train]
    periods = periods[train]
    energy_dummy = np.zeros((X.shape[0], 1))
    X_err = X_err[train]
    sample_weight = 1. / X_err

    supports_valid = np.concatenate((means_valid, scales_valid, periods_valid), axis=1)
    supports = np.concatenate((means, scales, periods), axis=1)

    num_supports_train = supports.shape[-1]
    num_supports_valid = supports_valid.shape[-1]
    assert(num_supports_train == num_supports_valid)
    num_additional = num_supports_train + 2 # 2 for reconstruction error

    if (args.gmm_on):
      gmm = GMM(args.num_classes, args.embedding+num_additional)

    ### Convert labels into one-hot vectors for training
    label_encoder = LabelEncoder()
    label_encoder.fit(y)
    ### Transform the integers into one-hot vector for softmax
    train_y_encoded = label_encoder.transform(y)
    train_y   = np_utils.to_categorical(train_y_encoded)
    ### Repeat for validation dataset 
    label_encoder = LabelEncoder()
    label_encoder.fit(y_valid)
    ### Transform the integers into one-hot vector for softmax
    valid_y_encoded = label_encoder.transform(y_valid)
    valid_y   = np_utils.to_categorical(valid_y_encoded)

    main_input = Input(shape=(X.shape[1], 2), name='main_input') # dim: (200, 2) = dt, mag
    aux_input  = Input(shape=(X.shape[1], 1), name='aux_input') # dim: (200, 1) = dt's

    model_input = [main_input, aux_input]
    if (args.gmm_on):
      support_input = Input(shape=(num_supports_train,), name='support_input') 
      model_input = [main_input, aux_input, support_input]

    encode = encoder(main_input, layer=model_type_dict[args.model_type], 
                     output_size=args.embedding, **vars(args))
    decode = decoder(encode, num_layers=args.decode_layers if args.decode_layers
                                                           else args.num_layers,
                     layer=model_type_dict[args.decode_type if args.decode_type
                                           else args.model_type],
                     n_step=X.shape[1], aux_input=aux_input,
                     **{k: v for k, v in vars(args).items() if k != 'num_layers'})
    optimizer = Adam(lr=args.lr if not args.finetune_rate else args.finetune_rate)

    if (not args.gmm_on):
      model = Model(model_input, decode)

      model.compile(optimizer=optimizer, loss='mse',
                    sample_weight_mode='temporal')

    else: 
      est_net = EstimationNet(args.estnet_size, args.num_classes, K.tanh)

      # extract x_i and hat{x}_{i} for reconstruction error feature calculations
      main_input_slice = Lambda(lambda x: x[:,:,1], name="main_slice")(main_input)
      decode_slice = Lambda(lambda x: x[:,:,0], name="decode_slice")(decode)
      z_both = Lambda(extract_features, name="concat_zs")([main_input_slice, decode_slice, encode, support_input])

      gamma = est_net.inference(z_both, args.estnet_drop_frac)
      print ("z_both shape", z_both.shape)
      print ("gamma shape", gamma.shape)

      sigma_i = Lambda(lambda x: gmm.fit(x[0], x[1]), 
                       output_shape=(args.num_classes, args.embedding+num_additional, args.embedding+num_additional), 
                       name="sigma_i")([z_both, gamma])

      energy = Lambda(lambda x: gmm.energy(x[0], x[1]), name="energy")([z_both, sigma_i])

      model_output = [decode, gamma, energy]

      model = Model(model_input, model_output)

      optimizer = Adam(lr=args.lr if not args.finetune_rate else args.finetune_rate)

      model.compile(optimizer=optimizer, 
                    loss=['mse', 'categorical_crossentropy', energy_loss],
                    loss_weights=[1.0, 1.0, args.lambda1],
                    metrics={'gamma':'accuracy'},
                    sample_weight_mode=['temporal', None, None])

      # summary for checking dimensions
      print(model.summary())


    if (args.gmm_on): 
      history = ku.train_and_log( {'main_input':X, 'aux_input':np.delete(X, 1, axis=2), 'support_input':supports},
                                  {'time_dist':X[:, :, [1]], 'gamma':train_y, 'energy':energy_dummy}, run, model, 
                                  sample_weight={'time_dist':sample_weight, 'gamma':None, 'energy':None},
                                  validation_data=(
                                  {'main_input':X_valid, 'aux_input':np.delete(X_valid, 1, axis=2), 'support_input':supports_valid}, 
                                  {'time_dist':X_valid[:, :, [1]], 'gamma':valid_y, 'energy':energy_dummy_valid},
                                  {'time_dist':sample_weight_valid, 'gamma':None, 'energy':None}), 
                                  **vars(args))
    else: # Just train RNN autoencoder
      history = ku.train_and_log({'main_input':X, 'aux_input':np.delete(X, 1, axis=2)},
                                  X[:, :, [1]], run, model,
                                  sample_weight=sample_weight,
                                  validation_data=(
                                  {'main_input':X_valid, 'aux_input':np.delete(X_valid, 1, axis=2)},
                                  X_valid[:, :, [1]], sample_weight_valid), **vars(args))

    return X, X_raw, model, means, scales, wrong_units, args
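With gmm_on enabled, the per-sample energies can be read back out of the trained graph by re-wiring the 'energy' output, mirroring how the novelty-detection script below extracts the 'gamma' layer; a short sketch using the names defined above:

energy_model = Model(inputs=model.input, outputs=model.get_layer('energy').output)
train_energy = energy_model.predict({'main_input': X,
                                     'aux_input': np.delete(X, 1, axis=2),
                                     'support_input': supports})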
Example #5
def main(args=None):
    args = ku.parse_model_args(args)

    args.loss = 'categorical_crossentropy'

    np.random.seed(0)

    if not args.survey_files:
        raise ValueError("No survey files given")
    classes = [
        'RR_Lyrae_FM', 'W_Ursae_Maj', 'Classical_Cepheid', 'Beta_Persei',
        'Semireg_PV'
    ]
    lc_lists = [joblib.load(f) for f in args.survey_files]
    combined = [lc for lc_list in lc_lists for lc in lc_list]
    combined = [lc for lc in combined if lc.label in classes]
    if args.lomb_score:
        combined = [lc for lc in combined if lc.best_score >= args.lomb_score]
    split = [el for lc in combined for el in lc.split(args.n_min, args.n_max)]
    if args.period_fold:
        for lc in split:
            lc.period_fold()
    X_list = [np.c_[lc.times, lc.measurements, lc.errors] for lc in split]

    classnames, y_inds = np.unique([lc.label for lc in split],
                                   return_inverse=True)
    Y = to_categorical(y_inds, len(classnames))

    X_raw = pad_sequences(X_list, value=np.nan, dtype='float', padding='post')
    X, means, scales, wrong_units = preprocess(X_raw, args.m_max)
    Y = Y[~wrong_units]

    # Remove errors
    X = X[:, :, :2]

    if args.even:
        X = X[:, :, 1:]


#    shuffled_inds = np.random.permutation(np.arange(len(X)))
#    train = np.sort(shuffled_inds[:args.N_train])
#    valid = np.sort(shuffled_inds[args.N_train:])
    train, valid = list(
        StratifiedKFold(n_splits=5, shuffle=True,
                        random_state=0).split(X_list, y_inds))[0]

    model_type_dict = {
        'gru': GRU,
        'lstm': LSTM,
        'vanilla': SimpleRNN,
        'conv': Conv1D
    }  #, 'atrous': AtrousConv1D, 'phased': PhasedLSTM}

    #    if args.pretrain:
    #        auto_args = {k: v for k, v in args.__dict__.items() if k != 'pretrain'}
    #        auto_args['sim_type'] = args.pretrain
    ##        auto_args['no_train'] = True
    #        auto_args['epochs'] = 1; auto_args['loss'] = 'mse'; auto_args['batch_size'] = 32; auto_args['sim_type'] = 'test'
    #        _, _, auto_model, _ = survey_autoencoder(auto_args)
    #        for layer in auto_model.layers:
    #            layer.trainable = False
    #        model_input = auto_model.input[0]
    #        encode = auto_model.get_layer('encoding').output
    #    else:
    #        model_input = Input(shape=(X.shape[1], X.shape[-1]), name='main_input')
    #        encode = encoder(model_input, layer=model_type_dict[args.model_type],
    #                         output_size=args.embedding, **vars(args))
    model_input = Input(shape=(X.shape[1], X.shape[-1]), name='main_input')
    encode = encoder(model_input,
                     layer=model_type_dict[args.model_type],
                     output_size=args.embedding,
                     **vars(args))

    scale_param_input = Input(shape=(2, ), name='scale_params')
    merged = merge([encode, scale_param_input], mode='concat')

    out = Dense(args.size + 2, activation='relu')(merged)
    out = Dense(Y.shape[-1], activation='softmax')(out)
    model = Model([model_input, scale_param_input], out)

    run = ku.get_run_id(**vars(args))
    if args.pretrain:
        for layer in model.layers:
            layer.trainable = False
        pretrain_weights = os.path.join('keras_logs', args.pretrain, run,
                                        'weights.h5')
    else:
        pretrain_weights = None

    history = ku.train_and_log(
        [X[train], np.c_[means, scales][train]],
        Y[train],
        run,
        model,
        metrics=['accuracy'],
        validation_data=([X[valid], np.c_[means, scales][valid]], Y[valid]),
        pretrain_weights=pretrain_weights,
        **vars(args))
    return X, X_raw, Y, model, args
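A brief sketch of scoring the trained classifier on the held-out fold, using the names defined above:

probs = model.predict([X[valid], np.c_[means, scales][valid]])
pred_labels = classnames[np.argmax(probs, axis=1)]
accuracy = np.mean(np.argmax(probs, axis=1) == np.argmax(Y[valid], axis=1))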
def main(args=None):
    """Train an autoencoder model from `LightCurve` objects saved in
    `args.survey_files`.
    
    args: dict
        Dictionary of values to override default values in `keras_util.parse_model_args`;
        can also be passed via command line. See `parse_model_args` for full list of
        possible arguments.
    """
    args = ku.parse_model_args(args)

    np.random.seed(0)

    if not args.survey_files:
        raise ValueError("No survey files given")
    lc_lists = [joblib.load(f) for f in args.survey_files]
    n_reps = [max(len(y) for y in lc_lists) // len(x) for x in lc_lists]
    combined = sum([x * i for x, i in zip(lc_lists, n_reps)], [])
    if args.lomb_score:
        combined = [lc for lc in combined if lc.best_score >= args.lomb_score]
    if args.ss_resid:
        combined = [lc for lc in combined if lc.ss_resid <= args.ss_resid]
    split = [el for lc in combined for el in lc.split(args.n_min, args.n_max)]
    if args.period_fold:
        for lc in split:
            lc.period_fold()
    X_list = [np.c_[lc.times, lc.measurements, lc.errors] for lc in split]

    X_raw = pad_sequences(X_list, value=np.nan, dtype='float', padding='post')

    model_type_dict = {'gru': GRU, 'lstm': LSTM, 'vanilla': SimpleRNN}
    X, means, scales, wrong_units = preprocess(X_raw, args.m_max)
    main_input = Input(shape=(X.shape[1], X.shape[-1]), name='main_input')
    aux_input = Input(shape=(X.shape[1], X.shape[-1] - 1), name='aux_input')
    model_input = [main_input, aux_input]
    encode = encoder(main_input,
                     layer=model_type_dict[args.model_type],
                     output_size=args.embedding,
                     **vars(args))
    decode = decoder(encode,
                     num_layers=args.decode_layers if args.decode_layers else args.num_layers,
                     layer=model_type_dict[args.decode_type if args.decode_type
                                           else args.model_type],
                     n_step=X.shape[1],
                     aux_input=aux_input,
                     **{k: v for k, v in vars(args).items() if k != 'num_layers'})
    model = Model(model_input, decode)

    run = ku.get_run_id(**vars(args))

    errors = X_raw[:, :, 2] / scales
    sample_weight = 1. / errors
    sample_weight[np.isnan(sample_weight)] = 0.0
    X[np.isnan(X)] = 0.

    history = ku.train_and_log(
        {
            'main_input': X,
            'aux_input': np.delete(X, 1, axis=2)
        },
        X[:, :, [1]],
        run,
        model,
        sample_weight=sample_weight,
        errors=errors,
        validation_split=0.0,
        **vars(args))

    return X, X_raw, model, means, scales, wrong_units, args
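The learned embeddings can be pulled from the returned model through its 'encoding' layer (the same layer name the novelty-detection script below relies on); a minimal sketch:

encode_model = Model(inputs=model.input, outputs=model.get_layer('encoding').output)
embeddings = encode_model.predict({'main_input': X,
                                   'aux_input': np.delete(X, 1, axis=2)})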
def main(args=None):
    args = parse_model_args(args)

    K.set_floatx('float64')

    run = get_run_id(**vars(args))
    log_dir = os.path.join(os.getcwd(), 'keras_logs', args.sim_type, run)
    weights_path = os.path.join(log_dir, 'weights.h5')

    if not os.path.exists(weights_path):
        raise FileNotFoundError(weights_path)

    X_fold, X_raw_fold, model, means, scales, wrong_units, args = survey_autoencoder(
        vars(args))

    print("log_dir", log_dir)
    print("Weight matrix read...")

    full = joblib.load(args.survey_files[0])

    # Combine subclasses into eight superclasses:
    # CEPH, DSCT, ECL, RRAB, RRCD, M, ROT, SR
    for lc in full:
        if lc.label in ('EW', 'EA', 'EB'):
            lc.label = 'ECL'
        if lc.label in ('CWA', 'CWB', 'DCEP', 'DCEPS', 'RVA'):
            lc.label = 'CEPH'
        if lc.label in ('DSCT', 'HADS'):
            lc.label = 'DSCT'
        if lc.label in ('RRD', 'RRC'):
            lc.label = 'RRCD'
    top_classes = ['SR', 'RRAB', 'RRCD', 'M', 'ROT', 'ECL', 'CEPH', 'DSCT']
    new_classes = ['VAR']

    top = [lc for lc in full if lc.label in top_classes]
    new = [lc for lc in full if lc.label in new_classes]

    # Preparation for classification probability
    classprob_pkl = joblib.load("./data/asassn/class_probs.pkl")
    class_probability = dict(classprob_pkl)

    if args.ss_resid:
        top = [lc for lc in top if lc.ss_resid <= args.ss_resid]

    if args.class_prob:
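        # NOTE: the probability cut below is hard-coded to 0.9 rather than
        # using args.class_prob as the threshold.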
        top = [
            lc for lc in top
            if float(class_probability[lc.name.split("/")[-1][2:-4]]) >= 0.9
        ]
        #top = [lc for lc in top if lc.class_prob >= args.class_prob]

    split = [el for lc in top for el in lc.split(args.n_min, args.n_max)]
    split_new = [el for lc in new for el in lc.split(args.n_min, args.n_max)]

    if args.period_fold:
        for lc in split:
            lc.period_fold()
        for lc in split_new:
            if lc.p is not None:
                lc.period_fold()

    X_list = [np.c_[lc.times, lc.measurements, lc.errors] for lc in split]
    classnames, indices = np.unique([lc.label for lc in split],
                                    return_inverse=True)
    y = classnames[indices]
    periods = np.array([np.log10(lc.p) for lc in split])
    periods = periods.reshape(len(split), 1)

    X_raw = pad_sequences(X_list, value=0., dtype='float64', padding='post')
    X, means, scales, wrong_units, X_err = preprocess(X_raw, args.m_max)
    y = y[~wrong_units]
    periods = periods[~wrong_units]

    train, valid = list(
        StratifiedKFold(n_splits=5, shuffle=True,
                        random_state=SEED).split(X, y))[0]

    X_train = X[train]
    y_train = y[train]
    means_train = means[train]
    scales_train = scales[train]
    periods_train = periods[train]
    energy_dummy = np.zeros((X_train.shape[0], 1))

    X_valid = X[valid]
    y_valid = y[valid]
    means_valid = means[valid]
    scales_valid = scales[valid]
    periods_valid = periods[valid]
    energy_dummy_valid = np.zeros((X_valid.shape[0], 1))

    supports_train = np.concatenate((means_train, scales_train, periods_train),
                                    axis=1)
    supports_valid = np.concatenate((means_valid, scales_valid, periods_valid),
                                    axis=1)

    # New class data (VAR type)
    X_list_new = [
        np.c_[lc.times, lc.measurements, lc.errors] for lc in split_new
    ]
    classnames_new, indices_new = np.unique([lc.label for lc in split_new],
                                            return_inverse=True)
    y_new = classnames_new[indices_new]
    periods_new = np.array(
        [np.log10(lc.p) if lc.p is not None else 99.0 for lc in split_new])
    periods_new = periods_new.reshape(len(split_new), 1)

    X_raw_new = pad_sequences(X_list_new,
                              value=0.,
                              dtype='float64',
                              padding='post')
    X_new, means_new, scales_new, wrong_units_new, X_err_new = preprocess(
        X_raw_new, args.m_max)
    y_new = y_new[~wrong_units_new]
    periods_new = periods_new[~wrong_units_new]
    supports_new = np.concatenate((means_new, scales_new, periods_new), axis=1)

    ### Concatenating validation data and data from new classes for testing novelty detection
    y_new = np.concatenate((y_valid, y_new), axis=0)
    X_new = np.concatenate((X_valid, X_new), axis=0)
    supports_new = np.concatenate((supports_valid, supports_new), axis=0)

    num_supports_train = supports_train.shape[-1]
    num_supports_valid = supports_valid.shape[-1]
    assert (num_supports_train == num_supports_valid)
    num_additional = num_supports_train + 2  # 2 for reconstruction error

    ### flagging novel samples as 1, samples in superclasses 0.
    true_flags = np.array([1 if (l not in top_classes) else 0 for l in y_new])

    ### Making one-hot labels
    label_encoder1 = LabelEncoder()
    label_encoder1.fit(y_train)
    train_y_encoded = label_encoder1.transform(y_train)
    train_y = np_utils.to_categorical(train_y_encoded)

    label_encoder2 = LabelEncoder()
    label_encoder2.fit(y_valid)
    valid_y_encoded = label_encoder2.transform(y_valid)
    valid_y = np_utils.to_categorical(valid_y_encoded)

    label_encoder3 = LabelEncoder()
    label_encoder3.fit(y_new)
    new_y_encoded = label_encoder3.transform(y_new)
    new_y = np_utils.to_categorical(new_y_encoded)

    ### Loading trained autoencoder network
    encode_model = Model(inputs=model.input,
                         outputs=model.get_layer('encoding').output)
    decode_model = Model(inputs=model.input,
                         outputs=model.get_layer('time_dist').output)

    X_train = np.float64(X_train)
    X_valid = np.float64(X_valid)
    X_new = np.float64(X_new)

    # Passing samples through trained layers
    encoding_train = encode_model.predict({'main_input': X_train,
                                           'aux_input': np.delete(X_train, 1, axis=2),
                                           'support_input': supports_train})
    encoding_valid = encode_model.predict({'main_input': X_valid,
                                           'aux_input': np.delete(X_valid, 1, axis=2),
                                           'support_input': supports_valid})
    encoding_new = encode_model.predict({'main_input': X_new,
                                         'aux_input': np.delete(X_new, 1, axis=2),
                                         'support_input': supports_new})

    decoding_train = decode_model.predict({'main_input': X_train,
                                           'aux_input': np.delete(X_train, 1, axis=2),
                                           'support_input': supports_train})
    decoding_valid = decode_model.predict({'main_input': X_valid,
                                           'aux_input': np.delete(X_valid, 1, axis=2),
                                           'support_input': supports_valid})
    decoding_new = decode_model.predict({'main_input': X_new,
                                         'aux_input': np.delete(X_new, 1, axis=2),
                                         'support_input': supports_new})

    z_both_train = extract_features([
        X_train[:, :, 1], decoding_train[:, :, 0], encoding_train,
        supports_train
    ])
    z_both_valid = extract_features([
        X_valid[:, :, 1], decoding_valid[:, :, 0], encoding_valid,
        supports_valid
    ])
    z_both_new = extract_features(
        [X_new[:, :, 1], decoding_new[:, :, 0], encoding_new, supports_new])

    z_both_train = K.eval(z_both_train)
    z_both_valid = K.eval(z_both_valid)
    z_both_new = K.eval(z_both_new)

    # Retrieve estimation network if gmm is on
    if (args.gmm_on):
        estnet_model = Model(inputs=model.input,
                             outputs=model.get_layer('gamma').output)
        gamma_train = estnet_model.predict({'main_input': X_train,
                                            'aux_input': np.delete(X_train, 1, axis=2),
                                            'support_input': supports_train})
        gamma_valid = estnet_model.predict({'main_input': X_valid,
                                            'aux_input': np.delete(X_valid, 1, axis=2),
                                            'support_input': supports_valid})
    else:
        est_net = EstimationNet(args.estnet_size, args.num_classes, K.tanh)

    # Fit data to gmm if joint, train gmm if sequential
    gmm = GMM(args.num_classes, args.embedding + num_additional)
    gmm_init = gmm.init_gmm_variables()

    # If sequential training, create and train the estimation net
    if (not args.gmm_on):
        z_input = Input(shape=(args.embedding + num_additional, ),
                        name='z_input')
        gamma = est_net.inference(z_input, args.estnet_drop_frac)

        sigma_i = Lambda(lambda x: gmm.fit(x[0], x[1]),
                         output_shape=(args.num_classes,
                                       args.embedding + num_additional,
                                       args.embedding + num_additional),
                         name="sigma_i")([z_input, gamma])

        energy = Lambda(lambda x: gmm.energy(x[0], x[1]),
                        name="energy")([z_input, sigma_i])

        # Setting up the GMM model
        model_output = [gamma, energy]
        model = Model(z_input, model_output)

        optimizer = Adam(
            lr=args.lr if not args.finetune_rate else args.finetune_rate)
        model.compile(optimizer=optimizer,
                      loss=['categorical_crossentropy', energy_loss],
                      metrics={'gamma': 'accuracy'},
                      loss_weights=[1.0, args.lambda1])

        # Controlling outputs
        estnet_size = args.estnet_size
        estsize = "_".join(str(s) for s in estnet_size)
        gmm_run = 'estnet{}_estdrop{}_l1{}'.format(
            estsize, int(100 * args.estnet_drop_frac), args.lambda1)
        gmm_dir = os.path.join(os.getcwd(), 'keras_logs', args.sim_type, run,
                               gmm_run)
        if (not os.path.isdir(gmm_dir)):
            os.makedirs(gmm_dir)

        # For classifier training only
        args.nb_epoch = args.cls_epoch

        history = ku.train_and_log(z_both_train,
                                   {'gamma': train_y, 'energy': energy_dummy},
                                   gmm_dir, model,
                                   validation_data=(z_both_valid,
                                                    {'gamma': valid_y,
                                                     'energy': energy_dummy_valid},
                                                    {'gamma': None, 'energy': None}),
                                   **vars(args))

        gamma_train, energy_train = model.predict(z_both_train)
        gamma_valid, energy_valid = model.predict(z_both_valid)

    if (not args.gmm_on):
        log_dir = gmm_dir

    plot_dir = os.path.join(log_dir, 'figures/')
    if (not os.path.isdir(plot_dir)):
        os.makedirs(plot_dir)

    # Converting to Keras variables to use Keras.backend functions
    z_both_train_K = K.variable(z_both_train)
    gamma_train_K = K.variable(gamma_train)

    z_both_valid_K = K.variable(z_both_valid)
    gamma_valid_K = K.variable(gamma_valid)

    z_both_new_K = K.variable(z_both_new)

    # Fitting GMM parameters only with training set
    sigma_i_train = gmm.fit(z_both_train_K, gamma_train_K)
    sigma_i_dummy = 1.0
    assert (gmm.fitted == True)

    # Energy calculation
    energy_train = K.eval(gmm.energy(z_both_train_K, sigma_i_train))[:, 0]
    energy_valid = K.eval(gmm.energy(z_both_valid_K, sigma_i_dummy))[:, 0]
    energy_new = K.eval(gmm.energy(z_both_new_K, sigma_i_dummy))[:, 0]

    energy_known = [
        e for i, e in enumerate(energy_new) if (true_flags[i] == 0)
    ]
    energy_unknown = [
        e for i, e in enumerate(energy_new) if (true_flags[i] == 1)
    ]
    print("known/unknown", len(energy_known), len(energy_unknown))

    gmm_phi = K.eval(gmm.phi)
    gmm_mu = K.eval(gmm.mu)
    gmm_sigma = K.eval(gmm.sigma)

    np.savez(log_dir + '/gmm_parameters.npz',
             gmm_phi=gmm_phi,
             gmm_mu=gmm_mu,
             gmm_sigma=gmm_sigma)

    percentile_list = [80.0, 95.0]

    txtfilename = log_dir + "/novel_detection_scores.txt"
    txtfile = open(txtfilename, 'w')

    for per in percentile_list:
        new_class_energy_threshold = np.percentile(energy_train, per)
        print(
            f"Energy threshold to detect new class: {new_class_energy_threshold:.2f}"
        )

        new_pred_flag = np.where(energy_new >= new_class_energy_threshold, 1,
                                 0)

        prec, recall, fscore, _ = precision_recall_fscore_support(
            true_flags, new_pred_flag, average="binary")

        print(f"Detecting new using {per:.1f}% percentile")
        print(f" Precision = {prec:.3f}")
        print(f" Recall    = {recall:.3f}")
        print(f" F1-Score  = {fscore:.3f}")

        txtfile.write(f"Detecting new using {per:.1f}% percentile \n")
        txtfile.write(f" Precision = {prec:.3f}\n")
        txtfile.write(f" Recall    = {recall:.3f}\n")
        txtfile.write(f" F1-Score  = {fscore:.3f}\n")
    txtfile.close()

    ### Make plots of energy
    nbin = 100
    fig = plt.figure()
    ax = fig.add_subplot(111)
    plt.hist(energy_train,
             nbin,
             normed=True,
             color='black',
             histtype='step',
             label='Training Set')
    plt.hist(np.asarray(energy_unknown)[np.isfinite(energy_unknown)],
             nbin,
             normed=True,
             color='blue',
             histtype='step',
             label='Unknown Classes')
    plt.hist(energy_known,
             nbin,
             normed=True,
             color='green',
             histtype='step',
             label='Known Classes')
    plt.legend()
    plt.xlabel(r"Energy E(z)")
    plt.ylabel("Probability")
    plt.savefig(plot_dir + 'energy_histogram.pdf',
                dpi=300,
                bbox_inches='tight')
    plt.clf()
    plt.cla()
    plt.close()

    ### Generate confusion matrix
    le_list = list(label_encoder2.classes_)
    predicted_onehot = gamma_valid
    predicted_labels = [
        le_list[np.argmax(onehot, axis=None, out=None)]
        for onehot in predicted_onehot
    ]

    corr_num, tot_num = plot_confusion_matrix(
        y_valid, predicted_labels, classnames,
        plot_dir + 'asassn_nn_confusion.pdf')
    nn_acc = corr_num / tot_num

    ### Generate confusion matrix for RF
    RF_PARAM_GRID = {
        'n_estimators': [50, 100, 250],
        'criterion': ['gini', 'entropy'],
        'max_features': [0.05, 0.1, 0.2, 0.3],
        'min_samples_leaf': [1, 2, 3]
    }
    rf_model = GridSearchCV(RandomForestClassifier(random_state=0),
                            RF_PARAM_GRID)
    rf_model.fit(encoding_train, y[train])

    rf_train_acc = 100 * rf_model.score(encoding_train, y[train])
    rf_valid_acc = 100 * rf_model.score(encoding_valid, y[valid])

    plot_confusion_matrix(y[valid], rf_model.predict(encoding_valid),
                          classnames, plot_dir + 'asassn_rf_confusion.pdf')

    ### Text output
    txtfilename = log_dir + "/classification_accuracy.txt"
    txtfile = open(txtfilename, 'w')

    ### Writing results
    txtfile.write("===== Classification Accuracy =====\n")
    txtfile.write(f"Neural Network Classifier: {nn_acc:.2f}\n")
    txtfile.write("==========================\n")
    txtfile.write("Random Forest Classifier\n")
    txtfile.write(f"Training accuracy: {rf_train_acc:2.2f}%\n")
    txtfile.write(f"Validation accuracy: {rf_valid_acc:2.2f}%\n")
    txtfile.write(f"Best RF {rf_model.best_params_}\n")
    txtfile.write("==========================\n")
    txtfile.close()
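For completeness, the gmm_parameters.npz written above can be reloaded later; this is the file the short GMM-decoding script earlier in this collection reads back:

gmm_para = np.load(os.path.join(log_dir, 'gmm_parameters.npz'))
gmm_phi, gmm_mu, gmm_sigma = (gmm_para[k] for k in ('gmm_phi', 'gmm_mu', 'gmm_sigma'))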