Beispiel #1
0
def main():
    """Evaluate a trained encoder/discriminator pair: extract features from
    the validation set and report mean average precision (MAP@5)."""
    args = args_parse()
    device = torch.device('cuda')

    # Restore the encoder in inference mode.
    enc_net = encoder(1, 1, 16).to(device)
    enc_net.load_state_dict(torch.load(args.enc, map_location=device))
    enc_net.eval()

    # Restore the discriminator in inference mode.
    dis_net = discriminator().to(device)
    dis_net.load_state_dict(torch.load(args.dis, map_location=device))
    dis_net.eval()

    models = {'enc': enc_net, 'dis': dis_net}

    # Validation data.
    val_set = ppmi()

    # Feature extraction followed by pairwise distances and relevance labels.
    ft_vectors, subject_list, im_list = extract_features(
        models, val_set, device)
    distance_matrix = get_dist(ft_vectors)
    rel_matrix = get_rel(subject_list, im_list)
    subject_dict = get_subject_dict(im_list)

    # Mean average precision with a cutoff of 5 retrieved items.
    MAP = compute_MAP(distance_matrix, im_list, subject_dict, rel_matrix, 5)
    print('MAP:{}'.format(MAP))
Beispiel #2
0
def compute_ms_ssim(dataset, encryptor=None, device=None):
    """Compute MS-SSIM scores over image pairs, optionally encrypting first.

    Parameters
    ----------
    dataset : torch Dataset yielding (x, x_ref, y, y_ref, d, d_p, d_n, im, im_ref)
        tuples; `d` labels the pair (1 = positive, 0 = negative).
    encryptor : str or None
        Path to encoder weights; when given, both images are passed through
        the encoder before scoring.
    device : torch.device or None
        Device used for the encoder forward pass (required if `encryptor`
        is given).

    Returns
    -------
    (pos_ms_ssim, neg_ms_ssim) : pair of lists of scores for positive and
        negative pairs respectively.
    """
    # PEP 8: compare to None with `is not`, not `!=` (fixed from `!= None`).
    if encryptor is not None:
        Encryptor = encoder(1, 1, 16).to(device)
        Encryptor.load_state_dict(torch.load(encryptor, map_location=device))
        Encryptor.eval()

    loader = data.DataLoader(dataset,
                             batch_size=1,
                             num_workers=16,
                             pin_memory=True,
                             shuffle=True)

    pos_ms_ssim = []
    neg_ms_ssim = []
    start = time()
    # The step index was unused, so plain iteration replaces enumerate().
    for x, x_ref, y, y_ref, d, d_p, d_n, im, im_ref in loader:
        if encryptor is not None:
            with torch.no_grad():
                x, x_ref = x.to(device), x_ref.to(device)
                x = Encryptor(x).cpu()
                x_ref = Encryptor(x_ref).cpu()

        # Drop batch/channel dims; ssim() expects a plain 3-D float array.
        x = x[0, 0, :, :, :].numpy().astype(float)
        x_ref = x_ref[0, 0, :, :, :].numpy().astype(float)
        ms_ssim_score = ssim(x, x_ref)
        if d.item() == 1:
            pos_ms_ssim.append(ms_ssim_score)
        elif d.item() == 0:
            neg_ms_ssim.append(ms_ssim_score)
    dur = time() - start
    print('|- Compute ms-ssim scores, duration: {:.0f} sec'.format(dur))
    return pos_ms_ssim, neg_ms_ssim
def main():
    """Validate pretrained encoder/segmentation/discriminator networks on a
    paired validation dataset via `val_epoch`."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder', type=str)
    parser.add_argument('--segmentation', type=str)
    parser.add_argument('--discriminator', type=str)
    parser.add_argument('--dataset', type=str, default='ppmi')
    args = parser.parse_args()

    encoder_path = args.encoder
    discriminator_path = args.discriminator
    segmentation_path = args.segmentation

    batch_size = 16
    device = torch.device('cuda')
    if args.dataset == 'ppmi':
        val_set = ppmi_pairs(mode='val')
    #if args.dataset == 'iseg':
    #    val_set = iseg_pairs()
    else:
        # Fail fast: previously an unrecognized dataset name left `val_set`
        # undefined and surfaced as a confusing NameError at val_epoch().
        raise ValueError('Unknown dataset: {}'.format(args.dataset))

    '''Initialize networks'''
    # init model encrypter
    Encrypter = encoder(1, 1, 16).to(device)
    Encrypter.load_state_dict(torch.load(encoder_path, map_location=device))
    print('Load:{}'.format(encoder_path))

    # init discriminator
    Discriminator = discriminator().to(device)
    Discriminator.load_state_dict(torch.load(discriminator_path, map_location=device))
    print('Load:{}'.format(discriminator_path))

    # init segmentator
    Segmentator = segnet(1, 6, 32).to(device)
    Segmentator.load_state_dict(torch.load(segmentation_path, map_location=device))
    print('Load:{}'.format(segmentation_path))

    models = {'enc': Encrypter, 'seg': Segmentator, 'dis': Discriminator}

    # declare loss functions: Dice for segmentation, CE for discrimination
    Segment_criterion = Dice_Loss()
    Discrimination_criterion = torch.nn.CrossEntropyLoss()
    criterions = {'seg': Segment_criterion, 'dis': Discrimination_criterion}

    seg_loss, adv_loss, dis_acc, dice_score = val_epoch(models, criterions, val_set, batch_size, device)
Beispiel #4
0
def main(args=None):
    """Train an RNN encoder on simulated periodic light curves and return
    the data, fitted scaler, model, and parsed arguments."""
    args = ku.parse_model_args(args)

    # Contiguous train/test index ranges over the sampled dataset.
    n_total = args.N_train + args.N_test
    train = np.arange(args.N_train)
    test = np.arange(args.N_test) + args.N_train

    X, Y, X_raw = sample_data.periodic(n_total,
                                       args.n_min,
                                       args.n_max,
                                       even=args.even,
                                       noise_sigma=args.sigma,
                                       kind=args.data_type)

    if args.even:
        # Evenly sampled: keep only the measurement channel.
        X = X[:, :, 1:2]
    else:
        # Uneven sampling: convert absolute times to lags and mask NaN
        # padding with -1 in both processed and raw arrays.
        X[:, :, 0] = ku.times_to_lags(X_raw[:, :, 0])
        X[np.isnan(X)] = -1.
        X_raw[np.isnan(X_raw)] = -1.

    # Targets: phase expressed as (sin, cos), standardized in place.
    Y = sample_data.phase_to_sin_cos(Y)
    scaler = StandardScaler(copy=False, with_mean=True, with_std=True)
    scaler.fit_transform(Y)
    if args.loss_weights:  # so far, only used to zero out some columns
        Y *= args.loss_weights

    rnn_layers = {'gru': GRU, 'lstm': LSTM, 'vanilla': SimpleRNN}

    model_input = Input(shape=(X.shape[1], X.shape[-1]), name='main_input')
    encode = encoder(model_input,
                     layer=rnn_layers[args.model_type],
                     output_size=Y.shape[-1],
                     **vars(args))
    model = Model(model_input, encode)

    run = ku.get_run_id(**vars(args))
    history = ku.train_and_log(X[train], Y[train], run, model, **vars(args))
    return X, Y, X_raw, scaler, model, args
Beispiel #5
0
def get_model(sess, image_shape=(80, 160, 3), gf_dim=64, df_dim=64, batch_size=64,
              name="transition", gpu=0):
    """Build a transition-model training graph on top of pretrained
    autoencoder parts (generator G, encoder E; both frozen) and return
    helper closures: (train_g, train_d, sampler, f_save, f_load, [G, E, T]).

    NOTE(review): Python 2 / legacy-TF code (`print` statements,
    tf.initialize_all_variables, tf.scalar_summary). It relies on module
    globals not visible in this chunk: out_leng, time, z_dim,
    learning_rate, beta1, G_file_path, E_file_path, transition, load,
    save — confirm they are defined in the enclosing module.
    """
    K.set_session(sess)
    checkpoint_dir = './results_' + name
    with tf.variable_scope(name):
      # sizes
      ch = image_shape[2]
      # NOTE(review): `/` on ints is floor division only under Python 2;
      # under Python 3 these become floats — confirm intended interpreter.
      rows = [image_shape[0]/i for i in [16, 8, 4, 2, 1]]
      cols = [image_shape[1]/i for i in [16, 8, 4, 2, 1]]

      # Pretrained generator and encoder; weights are loaded at the end of
      # this function and both are frozen so only T is trained here.
      G = autoencoder.generator(batch_size*out_leng, gf_dim, ch, rows, cols)
      G.compile("sgd", "mse")
      E = autoencoder.encoder(batch_size*(time+out_leng), df_dim, ch, rows, cols)
      E.compile("sgd", "mse")

      G.trainable = False
      E.trainable = False

      # nets: the transition network T is the only trainable component.
      T = transition(batch_size)
      T.compile("sgd", "mse")
      t_vars = T.trainable_weights
      print "T.shape: ", T.output_shape

      # Inputs: a clip of time+out_leng frames plus a 2-D control signal.
      Img = Input(batch_shape=(batch_size, time+out_leng,) + image_shape)
      Z = Input(batch_shape=(batch_size, time+out_leng, 2))  # controls signal
      # Fold the time axis into the batch axis so E sees single frames.
      I = K.reshape(Img, (batch_size*(time+out_leng),)+image_shape)
      code = E(I)[0]
      code = K.reshape(code, (batch_size, time+out_leng, z_dim))
      inp = K.concatenate([Z, code], axis=2)
      # Regression target: latent codes of the future frames only.
      target = code[:, time:, :]
      out = T(inp)
      G_dec = G(K.reshape(out, (batch_size*out_leng, z_dim)))

      # costs: MSE between predicted and encoded future latents.
      loss = tf.reduce_mean(tf.square(target - out))
      print "Transition variables:"
      for v in t_vars:
        print v.name

      t_optim = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(loss, var_list=t_vars)

      tf.initialize_all_variables().run()

    # summaries
    sum_loss = tf.scalar_summary("loss", loss)
    sum_e_mean = tf.histogram_summary("e_mean", code)
    sum_out = tf.histogram_summary("out", out)
    sum_dec = tf.image_summary("E", G_dec)

    # saver
    saver = tf.train.Saver()
    t_sum = tf.merge_summary([sum_e_mean, sum_out, sum_dec, sum_loss])
    writer = tf.train.SummaryWriter("/tmp/logs/"+name, sess.graph)

    # functions
    def train_d(images, z, counter, sess=sess):
      # Discriminator training is a no-op in this model; kept so the
      # returned closure tuple matches the caller's expected signature.
      return 0, 0, 0

    def train_g(images, z, counter, sess=sess):
      # One optimization step of T; also logs summaries and returns
      # decoded samples alongside the ground-truth future frames.
      outputs = [loss, G_dec, t_sum, t_optim]
      outs = sess.run(outputs, feed_dict={Img: images, Z: z, K.learning_phase(): 1})
      gl, samples, sums = outs[:3]
      writer.add_summary(sums, counter)
      images = images[:, time:].reshape((-1, 80, 160, 3))[:64]
      samples = samples.reshape((-1, 80, 160, 3))[:64]
      return gl, samples, images

    def f_load():
      # Prefer the TF checkpoint; fall back to Keras-format T weights.
      try:
        return load(sess, saver, checkpoint_dir, name)
      except:
        print("Loading weights via Keras")
        T.load_weights(checkpoint_dir+"/T_weights.keras")

    def f_save(step):
      # Save both the TF checkpoint and the Keras-format T weights.
      save(sess, saver, checkpoint_dir, step, name)
      T.save_weights(checkpoint_dir+"/T_weights.keras", True)

    def sampler(z, x):
      # Autoregressive rollout: repeatedly encode the window, predict the
      # next latent with T, decode it, and shift it into the window.
      video = np.zeros((128, 80, 160, 3))
      print "Sampling..."
      for i in range(128):
        print i
        x = x.reshape((-1, 80, 160, 3))
        # code = E.predict(x, batch_size=batch_size*(time+1))[0]
        code = sess.run([E(I)[0]], feed_dict={I: x, Z: z, K.learning_phase(): 1})[0]
        code = code.reshape((batch_size, time+out_leng, z_dim))
        inp = np.concatenate([z, code], axis=-1)
        outs = T.predict(inp, batch_size=batch_size)  # [:, :out_leng, :]
        # imgs = G.predict(out, batch_size=batch_size)
        imgs = sess.run([G_dec], feed_dict={out: outs, Z: z, K.learning_phase(): 1})[0]
        video[i] = imgs[0]
        x = x.reshape((batch_size, time+out_leng, 80, 160, 3))
        x[0, :-1] = x[0, 1:]
        x[0, -1] = imgs[0]
        z[0, :-1] = z[0, 1:]
      video = video.reshape((batch_size, 2, 80, 160, 3))
      return video[:, 0], video[:, 1]

    # Load frozen autoencoder weights (paths are module globals).
    G.load_weights(G_file_path)
    E.load_weights(E_file_path)

    return train_g, train_d, sampler, f_save, f_load, [G, E, T]
        (-1, settings['image_size'][0], settings['image_size'][1],
         settings['num_channels'])).astype(np.float32)
    train_labels = to_categorical(train_labels)
    valid_dataset = valid_dataset.reshape(
        (-1, settings['image_size'][0], settings['image_size'][1],
         settings['num_channels'])).astype(np.float32)
    valid_labels_oh = to_categorical(valid_labels)

    print('Training set', train_dataset.shape, train_labels.shape)
    print('Validation set', valid_dataset.shape, valid_labels_oh.shape)

    input_img = Input(shape=(settings['image_size'][0],
                             settings['image_size'][1],
                             settings['num_channels']))

    encode = encoder(input_img)
    full_model = Model(input_img, fc(encode))

    autoencoder = Model(input_img, decoder(encoder(input_img)))
    autoencoder.load_weights('autoencoder.h5')
    plot_model(autoencoder,
               to_file='./img/autoencoder.eps',
               show_shapes=True,
               show_layer_names=False)

    for l1, l2 in zip(full_model.layers[:12], autoencoder.layers[:12]):
        l1.set_weights(l2.get_weights())

    for layer in full_model.layers[:12]:
        layer.trainable = False
def main(args=None):
    """Train an autoencoder model from `LightCurve` objects saved in
    `args.survey_files`.
    
    args: dict
        Dictionary of values to override default values in `keras_util.parse_model_args`;
        can also be passed via command line. See `parse_model_args` for full list of
        possible arguments.
    """
    args = ku.parse_model_args(args)

    np.random.seed(0)

    K.set_floatx('float64')

    run = ku.get_run_id(**vars(args))

    if not args.survey_files:
        raise ValueError("No survey files given")

    # Balance surveys of different sizes: replicate each list so all are
    # roughly as long as the largest one, then concatenate.
    lc_lists = [joblib.load(f) for f in args.survey_files]
    n_reps = [max(len(y) for y in lc_lists) // len(x) for x in lc_lists]
    combined = sum([x * i for x, i in zip(lc_lists, n_reps)], [])

    # Preparation for classification probability
    classprob_pkl = joblib.load("./data/asassn/class_probs.pkl")
    class_probability = dict(classprob_pkl)

    # Combine subclasses into eight superclasses:
    # CEPH, DSCT, ECL, RRAB, RRCD, M, ROT, SR
    for lc in combined:
      if ((lc.label == 'EW') or (lc.label == 'EA') or (lc.label == 'EB')):
        lc.label = 'ECL'
      if ((lc.label == 'CWA') or (lc.label == 'CWB') or (lc.label == 'DCEP') or
          (lc.label == 'DCEPS') or (lc.label == 'RVA')):
        lc.label = "CEPH"
      if ((lc.label == 'DSCT') or (lc.label == 'HADS')):
        lc.label = "DSCT"
      if ((lc.label == 'RRD') or (lc.label == 'RRC')):
        lc.label = "RRCD"
    top_classes = ['SR', 'RRAB', 'RRCD', 'M', 'ROT', 'ECL', 'CEPH', 'DSCT']

    print ("Number of raw LCs:", len(combined))

    # Quality filters: Lomb-Scargle score, sum-of-squares residual, and
    # external classification probability (keyed by a slice of the filename).
    if args.lomb_score:
        combined = [lc for lc in combined if lc.best_score >= args.lomb_score]
    if args.ss_resid:
        combined = [lc for lc in combined if lc.ss_resid <= args.ss_resid]
    if args.class_prob:
        combined = [lc for lc in combined if float(class_probability[lc.name.split("/")[-1][2:-4]]) >= args.class_prob]
        #combined = [lc for lc in combined if lc.class_prob >= args.class_prob]

    # Select only superclasses for training
    combined = [lc for lc in combined if lc.label in top_classes]

    # Split each light curve into segments of n_min..n_max points.
    split = [el for lc in combined for el in lc.split(args.n_min, args.n_max)]
    if args.period_fold:
        for lc in split:
            lc.period_fold()

    # One (n_points, 3) array per segment: time, magnitude, error.
    X_list = [np.c_[lc.times, lc.measurements, lc.errors] for lc in split]
    classnames, indices = np.unique([lc.label for lc in split], return_inverse=True)
    y = classnames[indices]
    periods = np.array([np.log10(lc.p) for lc in split])
    periods = periods.reshape(len(split), 1)

    # Pad ragged segments with NaN; preprocess() masks/normalizes them.
    X_raw = pad_sequences(X_list, value=np.nan, dtype='float64', padding='post')

    model_type_dict = {'gru': GRU, 'lstm': LSTM, 'vanilla': SimpleRNN}
    X, means, scales, wrong_units, X_err = preprocess(X_raw, args.m_max)

    y = y[~wrong_units]
    periods = periods[~wrong_units]

    # Prepare the indices for training and validation
    # NOTE(review): SEED is a module global not visible in this chunk.
    train, valid = list(StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED).split(X, y))[0]

    X_valid = X[valid]
    y_valid = y[valid]
    means_valid = means[valid]
    scales_valid = scales[valid]
    periods_valid = periods[valid]
    energy_dummy_valid = np.zeros((X_valid.shape[0], 1))
    X_err_valid = X_err[valid]
    sample_weight_valid = 1. / X_err_valid

    X = X[train]
    y = y[train]
    means = means[train]
    scales = scales[train]
    periods = periods[train]
    energy_dummy = np.zeros((X.shape[0], 1))
    X_err = X_err[train]
    sample_weight = 1. / X_err

    # Per-sample auxiliary features fed to the estimation network.
    supports_valid = np.concatenate((means_valid, scales_valid, periods_valid), axis=1)
    supports = np.concatenate((means, scales, periods), axis=1)

    num_supports_train = supports.shape[-1]
    num_supports_valid = supports_valid.shape[-1]
    assert(num_supports_train == num_supports_valid)
    num_additional = num_supports_train + 2 # 2 for reconstruction error

    if (args.gmm_on):
      gmm = GMM(args.num_classes, args.embedding+num_additional)

    ### Covert labels into one-hot vectors for training
    label_encoder = LabelEncoder()
    label_encoder.fit(y)
    ### Transform the integers into one-hot vector for softmax
    train_y_encoded = label_encoder.transform(y)
    train_y   = np_utils.to_categorical(train_y_encoded)
    ### Repeat for validation dataset 
    # NOTE(review): fitting a fresh LabelEncoder on y_valid means the
    # class-to-index mapping (and one-hot width) can differ from training
    # if either split is missing a class — confirm this is intended.
    label_encoder = LabelEncoder()
    label_encoder.fit(y_valid)
    ### Transform the integers into one-hot vector for softmax
    valid_y_encoded = label_encoder.transform(y_valid)
    valid_y   = np_utils.to_categorical(valid_y_encoded)

    main_input = Input(shape=(X.shape[1], 2), name='main_input') # dim: (200, 2) = dt, mag
    aux_input  = Input(shape=(X.shape[1], 1), name='aux_input') # dim: (200, 1) = dt's

    model_input = [main_input, aux_input]
    if (args.gmm_on):
      support_input = Input(shape=(num_supports_train,), name='support_input') 
      model_input = [main_input, aux_input, support_input]

    encode = encoder(main_input, layer=model_type_dict[args.model_type], 
                     output_size=args.embedding, **vars(args))
    decode = decoder(encode, num_layers=args.decode_layers if args.decode_layers
                                                           else args.num_layers,
                     layer=model_type_dict[args.decode_type if args.decode_type
                                           else args.model_type],
                     n_step=X.shape[1], aux_input=aux_input,
                     **{k: v for k, v in vars(args).items() if k != 'num_layers'})
    optimizer = Adam(lr=args.lr if not args.finetune_rate else args.finetune_rate)

    if (not args.gmm_on):
      # Plain RNN autoencoder: reconstruct the magnitude channel.
      model = Model(model_input, decode)

      model.compile(optimizer=optimizer, loss='mse',
                    sample_weight_mode='temporal')

    else: 
      # DAGMM-style branch: estimation network + GMM energy term.
      est_net = EstimationNet(args.estnet_size, args.num_classes, K.tanh)

      # extract x_i and hat{x}_{i} for reconstruction error feature calculations
      main_input_slice = Lambda(lambda x: x[:,:,1], name="main_slice")(main_input)
      decode_slice = Lambda(lambda x: x[:,:,0], name="decode_slice")(decode)
      z_both = Lambda(extract_features, name="concat_zs")([main_input_slice, decode_slice, encode, support_input])

      gamma = est_net.inference(z_both, args.estnet_drop_frac)
      print ("z_both shape", z_both.shape)
      print ("gamma shape", gamma.shape)

      sigma_i = Lambda(lambda x: gmm.fit(x[0], x[1]), 
                       output_shape=(args.num_classes, args.embedding+num_additional, args.embedding+num_additional), 
                       name="sigma_i")([z_both, gamma])

      energy = Lambda(lambda x: gmm.energy(x[0], x[1]), name="energy")([z_both, sigma_i])

      model_output = [decode, gamma, energy]

      model = Model(model_input, model_output)

      optimizer = Adam(lr=args.lr if not args.finetune_rate else args.finetune_rate)

      model.compile(optimizer=optimizer, 
                    loss=['mse', 'categorical_crossentropy', energy_loss],
                    loss_weights=[1.0, 1.0, args.lambda1],
                    metrics={'gamma':'accuracy'},
                    sample_weight_mode=['temporal', None, None])

      # summary for checking dimensions
      print(model.summary())


    if (args.gmm_on): 
      history = ku.train_and_log( {'main_input':X, 'aux_input':np.delete(X, 1, axis=2), 'support_input':supports},
                                  {'time_dist':X[:, :, [1]], 'gamma':train_y, 'energy':energy_dummy}, run, model, 
                                  sample_weight={'time_dist':sample_weight, 'gamma':None, 'energy':None},
                                  validation_data=(
                                  {'main_input':X_valid, 'aux_input':np.delete(X_valid, 1, axis=2), 'support_input':supports_valid}, 
                                  {'time_dist':X_valid[:, :, [1]], 'gamma':valid_y, 'energy':energy_dummy_valid},
                                  {'time_dist':sample_weight_valid, 'gamma':None, 'energy':None}), 
                                  **vars(args))
    else: # Just train RNN autoencoder
      history = ku.train_and_log({'main_input':X, 'aux_input':np.delete(X, 1, axis=2)},
                                  X[:, :, [1]], run, model,
                                  sample_weight=sample_weight,
                                  validation_data=(
                                  {'main_input':X_valid, 'aux_input':np.delete(X_valid, 1, axis=2)},
                                  X_valid[:, :, [1]], sample_weight_valid), **vars(args))

    return X, X_raw, model, means, scales, wrong_units, args
Beispiel #8
0
def main(argv):
    """Train an image classifier whose convolutional front-end is
    initialized from a pretrained autoencoder.

    Reads IDX-format train/test image and label files, transfers encoder
    weights from `encModel`, trains the classification head with the
    encoder frozen, then fine-tunes end-to-end, reports test metrics,
    and writes predictions for the (original, unshuffled) training set.

    argv : raw command-line arguments, parsed by getCommandLineArgs.
    """
    #Get command line arguments
    inputFile, inputLabels, testFile, testLabels, encModel, readUpto, testUpto, outputFile = getCommandLineArgs(
        argv)

    #Get hyperparameters
    numConvLay, sizeConvFil, numConvFilPerLay, numNodes, epochs, batchSize = getExecParameters(
    )

    # Open files
    inputFile = open(inputFile, "rb")
    inputLabels = open(inputLabels, "rb")
    testFile = open(testFile, "rb")
    testLabels = open(testLabels, "rb")

    # Get IDX headers (magic number, image count, image dimensions)
    magicnumber1, numImages, numrows, numcols = getIdxHeaders(inputFile)
    magicnumber2, numImages2 = getIdxHeaders2(inputLabels)
    magicnumber3, numImagesTest, numrowsTrain, numcolsTrain = getIdxHeaders(
        testFile)
    magicnumber4, numImagesTest2 = getIdxHeaders2(testLabels)

    # Small check: train images and labels must agree in count
    if numImages != numImages2:
        print("Size of train file and label file do not match!")
        sys.exit()

    # Small check: test images and labels must agree in count
    if numImagesTest != numImagesTest2:
        print("Size of train file and label file do not match!")
        sys.exit()

    # Small check: train and test image dimensions must match
    if (numrows != numrowsTrain) or (numcols != numcolsTrain):
        print(
            "Image size in training file is not equal to image size in test file!"
        )
        sys.exit()
    '''
    print(magicnumber1, numImages, numrows, numcols)
    print(magicnumber2, numImages2)
    print(magicnumber3, numImagesTest, numrowsTrain, numcolsTrain)
    print(magicnumber4, numImagesTest2)
    '''

    # Clamp requested read counts to what the files actually contain
    if (readUpto is None) or (readUpto > numImages):
        readUpto = numImages

    if (testUpto is None) or (testUpto > numImagesTest):
        testUpto = numImagesTest

    #
    # Scan images from training set
    #

    img = Image(numrows, numcols)
    images = []  # kept only for the commented-out debug printing below
    train_images = np.zeros([readUpto, numrows, numcols])

    print("Scanning training images...")

    for i in range(readUpto):
        img.scan(inputFile, i)
        images.append(copy.deepcopy(img))
        train_images[i] = copy.deepcopy(img.pixels)
    inputFile.close()  # done with the raw image file

    train_images_original = train_images
    print(readUpto, "training images scanned!")

    #
    # Scan labels from training set
    #

    train_labels = np.zeros(readUpto, dtype=np.uint8)

    print("Scanning training labels...")

    for i in range(readUpto):
        b = inputLabels.read(1)
        train_labels[i] = int(int.from_bytes(b, "big"))
    inputLabels.close()

    print(readUpto, "training labels scanned!")
    '''
    # Print
    for i in range(readUpto):
        print("Label =", train_labels[i])
        images[i].print()
        print('')
    '''

    # Convert to one-hot encoding
    one_hot_train_labels = keras.utils.to_categorical(train_labels)

    #
    # Scan images from test set
    #

    img = Image(numrows, numcols)
    images2 = []  # kept only for the commented-out debug printing below
    test_images = np.zeros([testUpto, numrows, numcols])

    print("Scanning testing images...")

    for i in range(testUpto):
        img.scan(testFile, i)
        images2.append(copy.deepcopy(img))
        test_images[i] = copy.deepcopy(img.pixels)
    testFile.close()

    # Bug fix: this previously printed readUpto (the TRAIN count)
    print(testUpto, "testing images scanned!")

    #
    # Scan labels from test set
    #

    test_labels = np.zeros(testUpto, dtype=np.uint8)

    print("Scanning testing labels...")

    for i in range(testUpto):
        b = testLabels.read(1)
        test_labels[i] = int(int.from_bytes(b, "big"))
    testLabels.close()

    # Bug fix: this previously printed readUpto (the TRAIN count)
    print(testUpto, "testing labels scanned!")
    '''
    # Print
    for i in range(testUpto):
        print("Label =", test_labels[i])
        images2[i].print()
        print('')
    '''

    # Convert to one-hot encoding
    one_hot_test_labels = keras.utils.to_categorical(test_labels)

    #
    # Set up classificator model
    #

    # Define input shape
    input_img = keras.Input(shape=(numrows, numcols, 1))

    # Rebuild the full autoencoder so its trained weights can be loaded
    tempModel = models.Model(
        input_img,
        decoder(encoder(input_img, numConvLay, sizeConvFil, numConvFilPerLay),
                numConvLay, sizeConvFil, numConvFilPerLay))
    tempModel.compile(loss='mean_squared_error',
                      optimizer=optimizers.RMSprop())
    tempModel.load_weights(encModel)

    # Define classifier model: same encoder + fully-connected head
    classifier = models.Model(
        input_img,
        fullyConnected(
            encoder(input_img, numConvLay, sizeConvFil, numConvFilPerLay),
            numNodes))

    # Copy encoder weights from the autoencoder into the classifier
    trainedLayers = numConvLay * (2 * numConvFilPerLay + 1) + 1
    for l1, l2 in zip(classifier.layers[:trainedLayers],
                      tempModel.layers[0:trainedLayers]):
        l1.set_weights(l2.get_weights())
    '''
    # Check validity of weights
    print("\n\nWeights #1")
    print(tempModel.get_weights()[0][1])
    print("\n\nWeights #2")
    print(classifier.get_weights()[0][1])
    '''

    # Freeze the transferred encoder layers for the first training phase
    for layer in classifier.layers[0:trainedLayers]:
        layer.trainable = False

    # Compile classifier
    classifier.compile(loss=losses.categorical_crossentropy,
                       optimizer=optimizers.Adam(),
                       metrics=['accuracy'])

    # Print summary
    print(classifier.summary())

    # Split train data into train/validation (80/20)
    from sklearn.model_selection import train_test_split
    train_images, validation_images, train_labels, validation_labels = train_test_split(
        train_images, one_hot_train_labels, test_size=0.2, random_state=13)

    # Phase 1: train only the classification head
    history = classifier.fit(train_images,
                             train_labels,
                             batch_size=batchSize,
                             epochs=epochs,
                             verbose=1,
                             validation_data=(validation_images,
                                              validation_labels))
    '''
    # Plot
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.show()
    '''

    # Phase 2: unfreeze everything and fine-tune end-to-end
    for layer in classifier.layers[0:trainedLayers]:
        layer.trainable = True

    classifier.compile(loss=losses.categorical_crossentropy,
                       optimizer=optimizers.Adam(),
                       metrics=['accuracy'])

    history2 = classifier.fit(train_images,
                              train_labels,
                              batch_size=batchSize,
                              epochs=epochs,
                              verbose=1,
                              validation_data=(validation_images,
                                               validation_labels))

    # Plot fine-tuning loss curves
    plt.plot(history2.history['loss'])
    plt.plot(history2.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.show()

    # Predict on the test set
    predictions = classifier.predict(test_images)
    predictions = np.argmax(np.round(predictions), axis=1)

    # Print stats
    print('')
    from sklearn.metrics import classification_report
    target_names = ["Class {}".format(i) for i in range(10)]
    print(
        classification_report(test_labels,
                              predictions,
                              target_names=target_names))

    ans = input("Print some extra stats? (y/n) ...")

    if ans == "y":

        # Calculate correct labels (-1 marks a mismatch)
        correct = np.where(predictions == test_labels, predictions, -1)
        asdf = np.where(correct != -1)[0]
        print("\nFound", len(asdf), " correct labels")

        # Print first 10 correct
        j = 0
        print("Printing first 10 correct...")
        for i, c in enumerate(correct):
            if c != -1:
                print("correct[i] =", correct[i], ", test_labels[i] =",
                      test_labels[i])
                j += 1
            if j == 10:
                break

        # Calculate incorrect labels (-1 marks a match)
        incorrect = np.where(predictions != test_labels, predictions, -1)
        asdf2 = np.where(incorrect != -1)[0]
        print("\nFound", len(asdf2), " incorrect labels")

        # Print first 10 incorrect
        j = 0
        print("Printing first 10 incorrect...")
        for i, c in enumerate(incorrect):
            if c != -1:
                print("incorrect[i] =", incorrect[i], ", test_labels[i] =",
                      test_labels[i])
                j += 1
            if j == 10:
                break

    # Predict on the original (unshuffled) training set and save to disk
    predictions = classifier.predict(train_images_original)
    predictions = np.argmax(np.round(predictions), axis=1)
    write_predictions(predictions, outputFile)
Beispiel #9
0
def main(args):
    """Fine-tune an encoder/decoder image-translation network with a
    perceptual (Gram-matrix) loss computed by a frozen loss network,
    restoring all three sub-networks from checkpoints and saving the
    trained encoder/decoder at the end.

    args: parsed namespace with at least .data, .GPU, .rate, .epochs,
          .batch_size (and .lam, referenced only in commented-out code).
    """
    X_train, Y_train = load_data(args.data)
    # NOTE(review): Y_train_ is computed but never used below — possibly
    # intended for the (commented-out) L2 feature loss; confirm.
    Y_train_ = np.stack([Y_train.squeeze()] * 3, axis=3)
    # X_train_bw = (np.sum(X_train, axis=3)/(3)).reshape([-1,245,437,1])

    if args.GPU == 0:
        # Force CPU execution by hiding all GPU devices.
        config = tf.ConfigProto(device_count={'GPU': 0})
    else:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = False

    tf.reset_default_graph()
    # Placeholder shapes depend on which dataset is selected.
    # NOTE(review): if args.data is neither 0 nor 1, X/Y are undefined
    # and the code below raises NameError — confirm upstream validation.
    if args.data == 0:
        X = tf.placeholder(tf.float32, [None, 480, 640, 3])
        Y = tf.placeholder(tf.float32, [None, 480, 640, 1])
    elif args.data == 1:
        X = tf.placeholder(tf.float32, [None, 245, 437, 3])
        X_ = tf.placeholder(tf.float32, [None, 245, 437, 1])
        Y = tf.placeholder(tf.float32, [None, 245, 437, 1])
    is_training = tf.placeholder(tf.bool)

    # Trainable encoder/decoder plus a frozen "loss network" encoder that
    # extracts features from both the output and the target.
    with tf.variable_scope('Encoder') as enc:
        latent_y = encoder(X, is_training, args.data)
    with tf.variable_scope('Decoder') as dec:
        output = decoder(latent_y, is_training, args.data)
    with tf.variable_scope('Loss_Encoder') as ln:
        y_feats = autoencoder.encoder(output, is_training, args.data)
    with tf.variable_scope(ln, reuse=True):
        x_feats = autoencoder.encoder(Y, is_training, args.data)

    # L1 reconstruction + total-variation regularizer + Gram-matrix
    # perceptual loss over four feature levels.
    trans_loss = l1_norm(output - Y)
    reg_loss = TV_loss(output)
    feat_loss = gram_loss(y_feats[0], x_feats[0])
    feat_loss += gram_loss(y_feats[1], x_feats[1])
    feat_loss += gram_loss(y_feats[2], x_feats[2])
    feat_loss += gram_loss(y_feats[3], x_feats[3])

    # feat_loss = tf.nn.l2_loss(y_feats[0] - x_feats[0])
    # feat_loss += tf.nn.l2_loss(y_feats[1] - x_feats[1])
    # feat_loss += tf.nn.l2_loss(y_feats[2] - x_feats[2])
    # feat_loss += tf.nn.l2_loss(y_feats[3] - x_feats[3])

    # mean_loss = tf.reduce_mean(trans_loss + args.lam*feat_loss)
    mean_loss = tf.reduce_mean(trans_loss + 0.1 * reg_loss + 10.0 * feat_loss)
    tf.summary.scalar('loss', mean_loss)

    optimizer = tf.train.AdamOptimizer(learning_rate=args.rate)
    extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    enc_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                 scope='Encoder')
    dec_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                 scope='Decoder')
    loss_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope='Loss_Encoder')
    # Batch-norm update ops must run before each optimization step.
    with tf.control_dependencies(extra_update_ops):
        # NOTE(review): var_list is a nested list [enc_vars, dec_vars];
        # TF optimizers typically expect a flat list of Variables
        # (enc_vars + dec_vars) — confirm this trains both sub-networks.
        train_full = optimizer.minimize(mean_loss,
                                        var_list=[enc_vars, dec_vars])

    sess = tf.Session(config=config)
    enc_saver = tf.train.Saver(var_list=enc_vars)
    dec_saver = tf.train.Saver(var_list=dec_vars)

    loss_saver = tf.train.Saver(var_list=loss_vars)
    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter('./tb', sess.graph)

    # Initialize, then overwrite with pretrained weights for all three nets.
    sess.run(tf.global_variables_initializer())
    loss_saver.restore(sess, './loss_network/loss_network_enc')
    enc_saver.restore(sess, './trained_model/initial_model_enc')
    dec_saver.restore(sess, './trained_model/initial_model_dec')

    # NOTE(review): Y_train is passed for two arguments — presumably both
    # the target and the evaluation target; verify run_model's signature.
    _ = run_model(sess,
                  X,
                  X_,
                  Y,
                  is_training,
                  mean_loss,
                  X_train,
                  Y_train,
                  Y_train,
                  epochs=args.epochs,
                  batch_size=args.batch_size,
                  print_every=10,
                  training=train_full,
                  plot_losses=False,
                  writer=writer,
                  sum_vars=merged)

    model_name = './trained_model/final_model'
    # model_name += 'data_' + str(args.data)
    # model_name += '_epochs_' + str(args.epochs)
    # model_name += '_batchsize_' + str(args.batch_size)
    # model_name += '_rate_' + str(args.rate)
    enc_saver.save(sess, model_name + '_enc')
    dec_saver.save(sess, model_name + '_dec')
Beispiel #10
0
def main(args=None):
    """Train a variable-star classifier from `LightCurve` objects saved in
    `args.survey_files`.

    Each light curve is chunked, encoded with the configured recurrent or
    convolutional layer type, concatenated with its (mean, scale)
    normalization parameters, and fed to a softmax classification head over
    the five classes listed below.

    args: dict
        Dictionary of values to override default values in
        `keras_util.parse_model_args`; can also be passed via command line.
    """
    args = ku.parse_model_args(args)

    # Multi-class problem: force cross-entropy regardless of the CLI default.
    args.loss = 'categorical_crossentropy'

    np.random.seed(0)

    if not args.survey_files:
        raise ValueError("No survey files given")
    # Only these five variability classes are kept for training.
    classes = [
        'RR_Lyrae_FM', 'W_Ursae_Maj', 'Classical_Cepheid', 'Beta_Persei',
        'Semireg_PV'
    ]
    lc_lists = [joblib.load(f) for f in args.survey_files]
    combined = [lc for lc_list in lc_lists for lc in lc_list]
    combined = [lc for lc in combined if lc.label in classes]
    # Optional quality cut on the periodogram score.
    if args.lomb_score:
        combined = [lc for lc in combined if lc.best_score >= args.lomb_score]
    # Break each curve into chunks of n_min..n_max observations.
    split = [el for lc in combined for el in lc.split(args.n_min, args.n_max)]
    if args.period_fold:
        for lc in split:
            lc.period_fold()
    # Per-chunk feature matrix: columns are (time, measurement, error).
    X_list = [np.c_[lc.times, lc.measurements, lc.errors] for lc in split]

    classnames, y_inds = np.unique([lc.label for lc in split],
                                   return_inverse=True)
    Y = to_categorical(y_inds, len(classnames))

    # Pad ragged sequences with NaN so downstream code can mask padding.
    X_raw = pad_sequences(X_list, value=np.nan, dtype='float', padding='post')
    X, means, scales, wrong_units = preprocess(X_raw, args.m_max)
    Y = Y[~wrong_units]
    # NOTE(review): Y is filtered by wrong_units here, but X and the fold
    # indices below (computed from the unfiltered X_list/y_inds) are not
    # visibly filtered — confirm preprocess() leaves everything aligned.

    # Remove errors
    X = X[:, :, :2]

    # Evenly-sampled curves carry no information in the time column.
    if args.even:
        X = X[:, :, 1:]


#    shuffled_inds = np.random.permutation(np.arange(len(X)))
#    train = np.sort(shuffled_inds[:args.N_train])
#    valid = np.sort(shuffled_inds[args.N_train:])
    # Deterministic stratified split: first of five folds (~80/20).
    train, valid = list(
        StratifiedKFold(n_splits=5, shuffle=True,
                        random_state=0).split(X_list, y_inds))[0]

    model_type_dict = {
        'gru': GRU,
        'lstm': LSTM,
        'vanilla': SimpleRNN,
        'conv': Conv1D
    }  #, 'atrous': AtrousConv1D, 'phased': PhasedLSTM}

    #    if args.pretrain:
    #        auto_args = {k: v for k, v in args.__dict__.items() if k != 'pretrain'}
    #        auto_args['sim_type'] = args.pretrain
    ##        auto_args['no_train'] = True
    #        auto_args['epochs'] = 1; auto_args['loss'] = 'mse'; auto_args['batch_size'] = 32; auto_args['sim_type'] = 'test'
    #        _, _, auto_model, _ = survey_autoencoder(auto_args)
    #        for layer in auto_model.layers:
    #            layer.trainable = False
    #        model_input = auto_model.input[0]
    #        encode = auto_model.get_layer('encoding').output
    #    else:
    #        model_input = Input(shape=(X.shape[1], X.shape[-1]), name='main_input')
    #        encode = encoder(model_input, layer=model_type_dict[args.model_type],
    #                         output_size=args.embedding, **vars(args))
    model_input = Input(shape=(X.shape[1], X.shape[-1]), name='main_input')
    encode = encoder(model_input,
                     layer=model_type_dict[args.model_type],
                     output_size=args.embedding,
                     **vars(args))

    # Concatenate the learned embedding with the (mean, scale) parameters
    # removed during preprocessing.
    # NOTE(review): `merge(..., mode='concat')` is the legacy Keras 1.x API;
    # modern Keras replaces it with `layers.concatenate`.
    scale_param_input = Input(shape=(2, ), name='scale_params')
    merged = merge([encode, scale_param_input], mode='concat')

    out = Dense(args.size + 2, activation='relu')(merged)
    out = Dense(Y.shape[-1], activation='softmax')(out)
    model = Model([model_input, scale_param_input], out)

    run = ku.get_run_id(**vars(args))
    if args.pretrain:
        # NOTE(review): this freezes EVERY layer, including the freshly
        # added Dense head — confirm only the encoder was meant to freeze.
        for layer in model.layers:
            layer.trainable = False
        pretrain_weights = os.path.join('keras_logs', args.pretrain, run,
                                        'weights.h5')
    else:
        pretrain_weights = None

    history = ku.train_and_log(
        [X[train], np.c_[means, scales][train]],
        Y[train],
        run,
        model,
        metrics=['accuracy'],
        validation_data=([X[valid], np.c_[means, scales][valid]], Y[valid]),
        pretrain_weights=pretrain_weights,
        **vars(args))
    return X, X_raw, Y, model, args
Beispiel #11
0
# Hold-out labels for the rows excluded by `mask` (mask, dfLabel, trainData,
# testData and trainLabel are defined earlier in this script).
testLabel = dfLabel.iloc[~mask, 1].values

# simply impute missing data with mean value
# NOTE(review): sklearn.preprocessing.Imputer was deprecated in 0.20 and
# removed in 0.22 (replaced by sklearn.impute.SimpleImputer), so this script
# requires an old scikit-learn. axis=1 imputes along ROWS — confirm that is
# intended rather than the usual per-feature (axis=0) imputation.
from sklearn.preprocessing import Imputer
fill_NaN = Imputer(missing_values=np.nan, strategy='mean', axis=1)
imptrainData = pd.DataFrame(fill_NaN.fit_transform(trainData))
imptestData = pd.DataFrame(fill_NaN.fit_transform(testData))

# normalize features
# Z-score each split with its own column mean/std.
# NOTE(review): the test set is scaled by its own statistics, not the
# training set's — verify this is intentional.
nortrainData = (imptrainData - imptrainData.mean()) / imptrainData.std()
nortestData = (imptestData - imptestData.mean()) / imptestData.std()

#%% reduce feature dimension by autoencoder
from autoencoder import encoder

# Project both splits into the autoencoder's lower-dimensional feature space.
output_train, output_test = encoder(nortrainData, nortestData)

#%%%%%%%%%%%%%%%% use gbm #############
from sklearn.ensemble import GradientBoostingClassifier  #GBM algorithm
# NOTE(review): sklearn.grid_search was removed in 0.20; modern versions use
# sklearn.model_selection.GridSearchCV.
from sklearn.grid_search import GridSearchCV  #Performing grid search
from gbmModel import modelfit
from gbmModel import modelvali

#All of the rc settings are stored in a dictionary-like variable called matplotlib.rcParams
from matplotlib.pylab import rcParams
rcParams['figure.figsize'] = 12, 4

# Baseline GBM (default hyperparameters) on the autoencoded features.
gbm0 = GradientBoostingClassifier(random_state=10)
#gbm_tuned = GradientBoostingClassifier(learning_rate=0.005, n_estimators=1200,max_depth=9, min_samples_split=1200, min_samples_leaf=60, subsample=0.85, random_state=10, max_features=7,warm_start=True)
alggbm = modelfit(gbm0, output_train, trainLabel)
#%%
def _read_idx_images(path):
    """Read an IDX (MNIST byte-format) image file.

    Returns (images, dx, dy) where `images` has shape (N, dx, dy, 1) and
    float32 values scaled to [0, 1].
    """
    # Header: four 4-byte big-endian ints — magic number, count, rows, cols.
    # https://www.kaggle.com/hojjatk/read-mnist-dataset
    with open(path, "rb") as f:
        magicNum = int.from_bytes(f.read(4), byteorder="big")  # unused
        numOfImages = int.from_bytes(f.read(4), byteorder="big")
        dx = int.from_bytes(f.read(4), byteorder="big")
        dy = int.from_bytes(f.read(4), byteorder="big")
        buf = f.read(dx * dy * numOfImages)
    images = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
    images = images.reshape(-1, dx, dy, 1).astype('float32') / 255.
    return images, dx, dy


def _read_idx_labels(path):
    """Read an IDX (MNIST byte-format) label file into an int64 vector."""
    with open(path, "rb") as f:
        magicNum = int.from_bytes(f.read(4), byteorder="big")  # unused
        numOfLabels = int.from_bytes(f.read(4), byteorder="big")
        buf = f.read(1 * numOfLabels)
    return np.frombuffer(buf, dtype=np.uint8).astype(np.int64)


def _plot_prediction_grid(indices, images, predicted, actual, filename):
    """Plot up to 12 digits titled with predicted vs. true class, then save
    (tight bounding box) and show the figure."""
    for i, idx in enumerate(indices[:12]):
        plt.subplot(4, 3, i + 1)
        # NOTE(review): the hard-coded 28x28 reshape assumes MNIST-sized
        # images — confirm dx == dy == 28 for all inputs.
        plt.imshow(images[idx].reshape(28, 28),
                   cmap='gray',
                   interpolation='none')
        plt.title("Predicted {}, Class {}".format(predicted[idx],
                                                  actual[idx]))
        plt.tight_layout()
    plt.savefig(filename, bbox_inches='tight')
    plt.show()


def classification(trainSet, trainLabelsSet, testSet, testLabelsSet,
                   autoencoderPath, layers, fcNodes, batchSize, epochs):
    """Build and fine-tune a digit classifier from a pretrained autoencoder.

    The encoder half of the autoencoder saved at ``autoencoderPath`` is
    reused as a feature extractor topped with a fully connected head. The
    head is trained first with the encoder frozen, then the whole network is
    fine-tuned end to end. Interactive prompts offer prediction plots,
    loss/accuracy curves, and a rerun.

    Parameters
    ----------
    trainSet, testSet : str
        Paths to IDX image files (MNIST byte format).
    trainLabelsSet, testLabelsSet : str
        Paths to the matching IDX label files.
    autoencoderPath : str
        Saved Keras autoencoder model.
    layers : int
        Number of encoder layers; must match the loaded autoencoder or the
        program exits.
    fcNodes : int
        Width of the fully connected head.
    batchSize, epochs : int
        Hyperparameters used for both training phases.

    Returns
    -------
    int
        1 if the user asked to rerun with new hyperparameters, else 0.
    """
    trainImages, _, _ = _read_idx_images(trainSet)
    # dx/dy for the Input shape come from the TEST file — preserving the
    # original code, where the second read overwrote the first's dimensions.
    testImages, dx, dy = _read_idx_images(testSet)

    trainLabels = _read_idx_labels(trainLabelsSet)
    testLabels = _read_idx_labels(testLabelsSet)

    # conversion to one-hot arrays
    trainCatLabels = to_categorical(trainLabels)
    testCatLabels = to_categorical(testLabels)

    inChannel = 1
    inputImage = Input(shape=(dx, dy, inChannel))

    xTrain, xValid, trainLabel, validLabel = train_test_split(trainImages,
                                                              trainCatLabels,
                                                              test_size=0.2,
                                                              random_state=13)

    autoencoder = keras.models.load_model(autoencoderPath)
    # The encoder is the first half of the autoencoder (minus bottleneck).
    numEncoderLayers = int(len(autoencoder.layers) / 2) - 1
    if layers != numEncoderLayers:
        sys.exit(
            "We are sorry to announce that no encoder exists with the given number of layers in your autoencoder. Please try again."
        )
    # autoencoderWeights = autoencoder.load_weights('./autoencoderWeights.h5')

    encode = encoder(inputImage, layers)
    fullModel = Model(inputImage, fullyConnected(encode, fcNodes))

    # Copy pretrained encoder weights into the matching classifier layers.
    for l1, l2 in zip(fullModel.layers[:layers], autoencoder.layers[0:layers]):
        l1.set_weights(l2.get_weights())

    # Phase 1: freeze the encoder, train only the classifier head.
    for layer in fullModel.layers[0:numEncoderLayers]:
        layer.trainable = False

    # NOTE(review): these EarlyStopping callbacks are instantiated but never
    # passed to fit(callbacks=...), so they have no effect. Preserved as-is;
    # wire them into fit() if early stopping is actually desired.
    keras.callbacks.EarlyStopping(monitor='loss',
                                  min_delta=0,
                                  patience=2,
                                  verbose=0,
                                  mode='auto')
    keras.callbacks.EarlyStopping(monitor='accuracy',
                                  min_delta=0,
                                  patience=2,
                                  verbose=0,
                                  mode='auto')
    keras.callbacks.EarlyStopping(monitor='val_loss',
                                  min_delta=0,
                                  patience=2,
                                  verbose=0,
                                  mode='auto')
    keras.callbacks.EarlyStopping(monitor='val_accuracy',
                                  min_delta=0,
                                  patience=2,
                                  verbose=0,
                                  mode='auto')

    fullModel.compile(loss=keras.losses.categorical_crossentropy,
                      optimizer=keras.optimizers.Adam(),
                      metrics=['accuracy'])
    classifyTrain = fullModel.fit(xTrain,
                                  trainLabel,
                                  batch_size=batchSize,
                                  epochs=epochs,
                                  verbose=1,
                                  validation_data=(xValid, validLabel))
    fullModel.save_weights('autoencoderClassification.h5')

    # Phase 2: unfreeze everything and fine-tune end to end.
    for layer in fullModel.layers[0:numEncoderLayers]:
        layer.trainable = True

    fullModel.compile(loss=keras.losses.categorical_crossentropy,
                      optimizer=keras.optimizers.Adam(),
                      metrics=['accuracy'])
    classifyTrain = fullModel.fit(xTrain,
                                  trainLabel,
                                  batch_size=batchSize,
                                  epochs=epochs,
                                  verbose=1,
                                  validation_data=(xValid, validLabel))
    fullModel.save_weights('classificationComplete.h5')

    testEval = fullModel.evaluate(testImages, testCatLabels, verbose=0)  # noqa: F841 (kept from original; result unused)
    predictedClasses = fullModel.predict(testImages)
    predictedClasses = np.argmax(np.round(predictedClasses), axis=1)

    inp = input(
        "We have produced the classification model. Would you like to start with predictions? (Y/n) "
    )
    if inp == 'Y' or inp == 'y':
        print("Started making predictions...")

        correct = np.where(predictedClasses == testLabels)[0]
        print("Found " + str(len(correct)) + " correct labels")
        _plot_prediction_grid(correct, testImages, predictedClasses,
                              testLabels, 'correctPredictions.png')

        incorrect = np.where(predictedClasses != testLabels)[0]
        print("Found " + str(len(incorrect)) + " incorrect labels")
        plt.figure()
        _plot_prediction_grid(incorrect, testImages, predictedClasses,
                              testLabels, 'incorrectPredictions.png')

        print("More specifically:\n")
        print(classification_report(testLabels, predictedClasses))

    inp = input(
        "Would you like to plot your experiment's loss and accuracy results? (Y/n) "
    )
    if inp == 'Y' or inp == 'y':
        accuracy = classifyTrain.history['accuracy']
        val_accuracy = classifyTrain.history['val_accuracy']
        loss = classifyTrain.history['loss']
        val_loss = classifyTrain.history['val_loss']
        # Renamed from `epochs` — the original shadowed the parameter.
        epochRange = range(len(accuracy))
        plt.plot(epochRange, accuracy, 'bo', label='Training accuracy')
        plt.plot(epochRange, val_accuracy, 'b', label='Validation accuracy')
        plt.title('Training and validation accuracy')
        plt.legend()
        plt.figure()
        plt.plot(epochRange, loss, 'bo', label='Training loss')
        plt.plot(epochRange, val_loss, 'b', label='Validation loss')
        plt.title('Training and validation loss')
        plt.legend()
        plt.savefig('overfittingClassCheck.png', bbox_inches='tight')
        plt.show()

    inp = input(
        "Would you like to repeat your experiment with different hyperparameter values? (Y/n) "
    )
    if inp == 'Y' or inp == 'y':
        return 1

    return 0
def main(args=None):
    """Fit a sequence autoencoder on light curves from ``args.survey_files``.

    Every survey is oversampled toward the size of the largest one, cut by
    optional quality thresholds, chunked, padded, and normalized; the model
    reconstructs the measurement column from the encoded sequence plus an
    auxiliary input holding the remaining columns.

    args: dict
        Overrides for the defaults supplied by `keras_util.parse_model_args`;
        the same options may also be given on the command line.
    """
    args = ku.parse_model_args(args)

    np.random.seed(0)

    if not args.survey_files:
        raise ValueError("No survey files given")

    # Load each survey and repeat the smaller ones so every survey
    # contributes roughly as many curves as the largest.
    surveys = [joblib.load(fname) for fname in args.survey_files]
    largest = max(len(s) for s in surveys)
    pool = []
    for survey in surveys:
        pool.extend(survey * (largest // len(survey)))

    # Optional quality cuts.
    if args.lomb_score:
        pool = [lc for lc in pool if lc.best_score >= args.lomb_score]
    if args.ss_resid:
        pool = [lc for lc in pool if lc.ss_resid <= args.ss_resid]

    # Break long curves into chunks of n_min..n_max observations.
    chunks = []
    for lc in pool:
        chunks.extend(lc.split(args.n_min, args.n_max))
    if args.period_fold:
        for lc in chunks:
            lc.period_fold()

    # Feature columns per chunk: (time, measurement, error).
    sequences = [np.c_[lc.times, lc.measurements, lc.errors] for lc in chunks]
    X_raw = pad_sequences(sequences, value=np.nan, dtype='float',
                          padding='post')

    layer_types = {'gru': GRU, 'lstm': LSTM, 'vanilla': SimpleRNN}
    X, means, scales, wrong_units = preprocess(X_raw, args.m_max)

    main_input = Input(shape=(X.shape[1], X.shape[-1]), name='main_input')
    aux_input = Input(shape=(X.shape[1], X.shape[-1] - 1), name='aux_input')
    encode = encoder(main_input,
                     layer=layer_types[args.model_type],
                     output_size=args.embedding,
                     **vars(args))
    # The decoder may use its own depth/type; num_layers is stripped from
    # the catch-all kwargs because it is passed explicitly.
    decoder_kwargs = {k: v for k, v in vars(args).items()
                      if k != 'num_layers'}
    decode = decoder(
        encode,
        num_layers=(args.decode_layers
                    if args.decode_layers else args.num_layers),
        layer=layer_types[args.decode_type
                          if args.decode_type else args.model_type],
        n_step=X.shape[1],
        aux_input=aux_input,
        **decoder_kwargs)
    model = Model([main_input, aux_input], decode)

    run = ku.get_run_id(**vars(args))

    # Weight samples by inverse scaled error; NaN padding rows get weight 0
    # and their inputs are zeroed.
    errors = X_raw[:, :, 2] / scales
    sample_weight = 1. / errors
    sample_weight[np.isnan(sample_weight)] = 0.0
    X[np.isnan(X)] = 0.

    history = ku.train_and_log(
        {'main_input': X, 'aux_input': np.delete(X, 1, axis=2)},
        X[:, :, [1]],
        run,
        model,
        sample_weight=sample_weight,
        errors=errors,
        validation_split=0.0,
        **vars(args))

    return X, X_raw, model, means, scales, wrong_units, args
Beispiel #14
0
def classification_model(X, drop_out_prob=0.8):
    """Produce softmax class probabilities from encoded features.

    X: input tensor handed to the module-level `encoder`.
    drop_out_prob: dropout setting forwarded to the encoder as `drop_prob`.

    Returns the softmax prediction tensor. The final affine layer uses the
    module-level weight `W1` and bias `b1`.
    """
    features = encoder(X, drop_prob=drop_out_prob)
    normalized = tf.layers.batch_normalization(features,
                                               center=True,
                                               scale=True)
    logits = tf.add(tf.matmul(normalized, W1), b1)
    return tf.nn.softmax(logits)
Beispiel #15
0
def get_model(sess, image_shape=(80, 160, 3), gf_dim=64, df_dim=64, batch_size=64,
              name="transition", gpu=0):
    """Build a latent-space transition model between a pretrained encoder
    and generator, returning train/sample/checkpoint closures.

    Inside variable scope `name`: generator G and encoder E come from the
    `autoencoder` module and are frozen; transition net T is trained to
    predict the next latent code (MSE between shifted codes). Summaries,
    a Saver, and a SummaryWriter are created outside the scope.

    NOTE(review): Python 2 code (print statements) using pre-1.0 TensorFlow
    APIs (scalar_summary, initialize_all_variables, ...). Relies on
    module-level globals not visible here: `time`, `out_leng`, `z_dim`,
    `learning_rate`, `beta1`, `transition`, `load`, `save`, `G_file_path`,
    `E_file_path`. The `gpu` parameter is unused in this body.

    Returns (train_g, train_d, sampler, f_save, f_load, [G, E, T]).
    """
    K.set_session(sess)
    checkpoint_dir = './outputs/results_' + name
    with tf.variable_scope(name):
      # sizes
      ch = image_shape[2]
      # Feature-map sizes at each of the five down/up-sampling stages.
      rows = [image_shape[0]/i for i in [16, 8, 4, 2, 1]]
      cols = [image_shape[1]/i for i in [16, 8, 4, 2, 1]]

      # Pretrained generator/encoder; compiled only so Keras finalizes them,
      # then frozen — only T is trained here.
      G = autoencoder.generator(7*(time+out_leng-1), gf_dim, ch, rows, cols)
      G.compile("sgd", "mse")
      E = autoencoder.encoder(batch_size*(time+out_leng), df_dim, ch, rows, cols)
      E.compile("sgd", "mse")

      G.trainable = False
      E.trainable = False

      # nets
      T = transition(batch_size)
      T.compile("sgd", "mse")
      t_vars = T.trainable_weights
      print "T.shape: ", T.output_shape

      # Flatten the (batch, time) axes so E sees one frame per row, then
      # restore them to slice codes into input (first `time` steps) and
      # prediction target (steps shifted by one).
      Img = Input(batch_shape=(batch_size, time+out_leng,) + image_shape)
      I = K.reshape(Img, (batch_size*(time+out_leng),)+image_shape)
      code = E(I)[0]
      code = K.reshape(code, (batch_size, time+out_leng, z_dim))
      target = code[:, 1:, :]
      inp = code[:, :time, :]
      out = T(inp)
      # Decode the first 7 predicted codes back to images for visualization.
      G_dec = G(K.reshape(out[:7, :, :], (-1, z_dim)))

      # costs
      loss = tf.reduce_mean(tf.square(target - out))
      print "Transition variables:"
      for v in t_vars:
        print v.name

      # Only T's weights are optimized; G and E stay fixed.
      t_optim = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(loss, var_list=t_vars)

      tf.initialize_all_variables().run()

    # summaries
    sum_loss = tf.scalar_summary("loss", loss)
    sum_e_mean = tf.histogram_summary("e_mean", code)
    sum_out = tf.histogram_summary("out", out)
    sum_dec = tf.image_summary("E", G_dec)

    # saver
    saver = tf.train.Saver()
    t_sum = tf.merge_summary([sum_e_mean, sum_out, sum_dec, sum_loss])
    writer = tf.train.SummaryWriter("/tmp/logs/"+name, sess.graph)

    # functions
    # Discriminator training is a no-op in this model; kept for interface
    # parity with sibling get_model implementations.
    def train_d(images, z, counter, sess=sess):
      return 0, 0, 0

    # One optimization step on T; logs summaries and returns the loss plus
    # decoded sample / input frames for inspection.
    def train_g(images, z, counter, sess=sess):
      outputs = [loss, G_dec, t_sum, t_optim]
      outs = sess.run(outputs, feed_dict={Img: images, K.learning_phase(): 1})
      gl, samples, sums = outs[:3]
      writer.add_summary(sums, counter)
      images = images.reshape((-1, 80, 160, 3))[:64]
      samples = samples.reshape((-1, 80, 160, 3))[:64]
      return gl, samples, images

    # Restore from the TF checkpoint; fall back to Keras weights on failure.
    def f_load():
      try:
        return load(sess, saver, checkpoint_dir, name)
      except:
        print("Loading weights via Keras")
        T.load_weights(checkpoint_dir+"/T_weights.keras")

    def f_save(step):
      save(sess, saver, checkpoint_dir, step, name)
      T.save_weights(checkpoint_dir+"/T_weights.keras", True)

    # Roll the model forward 128 frames: encode the window, predict the next
    # code, decode it, and slide the predicted frame into the window.
    def sampler(z, x):
      video = np.zeros((128, 80, 160, 3))
      print "Sampling..."
      for i in range(128):
        print i
        x = x.reshape((-1, 80, 160, 3))
        # code = E.predict(x, batch_size=batch_size*(time+1))[0]
        code = sess.run([E(I)[0]], feed_dict={I: x, K.learning_phase(): 1})[0]
        code = code.reshape((batch_size, time+out_leng, z_dim))
        inp = code[:, :time]
        outs = T.predict(inp, batch_size=batch_size)
        # imgs = G.predict(out, batch_size=batch_size)
        imgs = sess.run([G_dec], feed_dict={out: outs, K.learning_phase(): 1})[0]
        video[i] = imgs[0]
        x = x.reshape((batch_size, time+out_leng, 80, 160, 3))
        x[0, :-1] = x[0, 1:]
        x[0, -1] = imgs[0]

      video = video.reshape((batch_size, 2, 80, 160, 3))
      return video[:, 0], video[:, 1]

    G.load_weights(G_file_path)
    E.load_weights(E_file_path)

    return train_g, train_d, sampler, f_save, f_load, [G, E, T]
        (-1, settings['image_size'][0], settings['image_size'][1],
         settings['num_channels'])).astype(np.float32)
    train_labels = to_categorical(train_labels)
    valid_dataset = valid_dataset.reshape(
        (-1, settings['image_size'][0], settings['image_size'][1],
         settings['num_channels'])).astype(np.float32)
    valid_labels_oh = to_categorical(valid_labels)

    print('Training set', train_dataset.shape, train_labels.shape)
    print('Validation set', valid_dataset.shape, valid_labels_oh.shape)

    input_img = Input(shape=(settings['image_size'][0],
                             settings['image_size'][1],
                             settings['num_channels']))

    encode = encoder(input_img)
    full_model = Model(input_img, fc(encode))

    full_model.compile(loss=keras.losses.categorical_crossentropy,
                       optimizer=keras.optimizers.Adam(lr=6e-5, decay=0.85e-7),
                       metrics=['accuracy'])

    classify_train = full_model.fit(train_dataset,
                                    train_labels,
                                    batch_size=256,
                                    epochs=600,
                                    verbose=1,
                                    validation_data=(valid_dataset,
                                                     valid_labels_oh))

    full_model.save_weights('classification_complete_noauto.h5')
def main():
    """Adversarial training loop for a privacy-preserving image encoder.

    Jointly trains an encrypter (encoder) and a segmentation network against
    a discriminator on PPMI image pairs. Command-line flags supply optional
    checkpoint paths for the three sub-networks, the batch size, the
    adversarial weight LAMBDA, and the output directory. Models and
    optimizer states are checkpointed to ``<save_dir>/models`` every epoch.

    Requires a CUDA device.
    """
    # --- Configuration ----------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder', type=str)
    parser.add_argument('--segmentation', type=str)
    parser.add_argument('--discriminator', type=str)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--LAMBDA', type=int)
    parser.add_argument('--save_dir', type=str)
    args = parser.parse_args()
    # NOTE(review): none of the flags have defaults or required=True, so a
    # missing flag surfaces later as None — confirm callers always pass all.

    LAMBDA = args.LAMBDA
    batch_size = args.batch_size
    learning_rate = 1e-4
    num_epochs = 500
    val_interval = 1   # validate every N epochs (was hard-coded `% 1`)
    save_interval = 1  # checkpoint every N epochs (was hard-coded `% 1`)

    device = torch.device('cuda')

    save_dir = args.save_dir
    # makedirs(exist_ok=True) avoids the exists()/mkdir race and also
    # creates missing parents (the original used exists() + mkdir).
    os.makedirs(save_dir, exist_ok=True)

    enc_path = args.encoder
    dis_path = args.discriminator
    seg_path = args.segmentation

    # --- Datasets ---------------------------------------------------------
    train_set = ppmi_pairs(mode='train')
    val_set = ppmi_pairs(mode='val')

    # --- Networks: resume from checkpoints when the files exist -----------
    Encrypter = encoder(1, 1, 16).to(device)
    if os.path.exists(enc_path):
        Encrypter.load_state_dict(torch.load(enc_path, map_location=device))
    Encrypter.train()

    Discriminator = discriminator().to(device)
    if os.path.exists(dis_path):
        Discriminator.load_state_dict(torch.load(dis_path,
                                                 map_location=device))
    Discriminator.train()

    Segmentator = segnet(1, 6, 32).to(device)
    if os.path.exists(seg_path):
        Segmentator.load_state_dict(torch.load(seg_path, map_location=device))
    Segmentator.train()

    models = {'enc': Encrypter, 'seg': Segmentator, 'dis': Discriminator}

    # --- Losses and optimizers --------------------------------------------
    Segment_criterion = Dice_Loss()
    Discrimination_criterion = torch.nn.CrossEntropyLoss()
    criterions = {'seg': Segment_criterion, 'dis': Discrimination_criterion}

    # Encoder + segmentator share one optimizer so they are updated jointly;
    # the discriminator gets its own so each side can be stepped separately.
    params_es = [{
        "params": Encrypter.parameters()
    }, {
        "params": Segmentator.parameters()
    }]
    optimizer_es = torch.optim.Adam(params_es, lr=learning_rate)

    params_d = [{"params": Discriminator.parameters()}]
    optimizer_d = torch.optim.Adam(params_d, lr=learning_rate)

    optimizers = {'es': optimizer_es, 'dis': optimizer_d}

    # --- Training loop ----------------------------------------------------
    for epoch in range(num_epochs):
        print('|==========================\nEPOCH:{}'.format(epoch + 1))
        _, _, _, _,\
        models, optimizers = train_epoch(models, optimizers, criterions, LAMBDA, train_set, batch_size,device)

        if (epoch + 1) % val_interval == 0:
            val_epoch(models, criterions, val_set, batch_size, device)

        if (epoch + 1) % save_interval == 0:
            # Checkpoint all models and optimizer states.
            models_dir = os.path.join(save_dir, 'models')
            os.makedirs(models_dir, exist_ok=True)
            torch.save(models['enc'].state_dict(),
                       os.path.join(models_dir, 'enc.pt'))
            torch.save(models['seg'].state_dict(),
                       os.path.join(models_dir, 'seg.pt'))
            torch.save(models['dis'].state_dict(),
                       os.path.join(models_dir, 'dis.pt'))
            torch.save(optimizers['es'].state_dict(),
                       os.path.join(models_dir, 'optim_es.pt'))
            torch.save(optimizers['dis'].state_dict(),
                       os.path.join(models_dir, 'optim_dis.pt'))

    return