def add_latents_to_dataset_using_tensors(args, sess, tensors, data):
    """
    Evaluate the model's latent tensors and store the results in `data`.

    Args:
        args: Arguments from parser in train_grocerystore.py. The flags
            use_private, use_text and use_iconic select which optional
            latents are computed; batch_size and z_dim are read only when
            use_iconic is set.
        sess: Tensorflow session.
        tensors: Dictionary of tensors used for feeding inputs and
            fetching latent representations.
        data: Data dictionary. 'features' is always read; 'captions',
            'labels' and 'iconic_image_paths' are read only for the
            corresponding optional latents.

    Returns:
        The same data dictionary, with 'latents' added and, depending on
        the flags, 'latents_ux', 'latents_uw' and 'latents_ui'.
    """
    # The feature feed is shared by the joint and the private latents.
    feature_feed = {tensors['x']: data['features']}
    data['latents'] = sess.run(tensors['latents'], feed_dict=feature_feed)

    if args.use_private:
        data['latents_ux'] = sess.run(tensors['latents_ux'],
                                      feed_dict=feature_feed)

    if args.use_text:
        caption_inputs = load_captions(data['captions'], data['labels'])
        data['latents_uw'] = sess.run(
            tensors['latents_uw'],
            feed_dict={tensors['captions']: caption_inputs})

    if args.use_iconic:
        # Iconic images are loaded and encoded chunk by chunk; the last
        # chunk is clamped to the number of available paths.
        image_paths = data['iconic_image_paths']
        chunk_size = args.batch_size
        n_paths = len(image_paths)
        n_chunks = int(np.ceil(n_paths / chunk_size))
        iconic_latents = np.zeros([n_paths, args.z_dim])
        for chunk_idx in range(n_chunks):
            lo = chunk_idx * chunk_size
            hi = min(lo + chunk_size, n_paths)
            images = load_iconic_images(image_paths[lo:hi])
            iconic_latents[lo:hi] = sess.run(
                tensors['latents_ui'],
                feed_dict={tensors['iconic_images']: images})
        data['latents_ui'] = iconic_latents

    return data
def run_training_epoch(args, data, model, hyperparams, session,
                       train_op=None, shuffle=False, mode='train'):
    """
    Execute one training or validation epoch for Autoencoder.

    Args:
        args: Arguments from parser in train_grocerystore.py. Not read by
            this function.
        data: Data used during epoch. The keys 'features', 'labels',
            'captions', 'iconic_image_paths' and 'n_classes' are read.
        model: Model used during epoch. Provides the placeholders that
            are fed (x, iconic_images, captions, labels, dropout_rate,
            is_training) and the fetch lists log_var / val_log_var.
        hyperparams: Hyperparameters for training. The keys 'batch_size',
            'dropout_rate' and 'is_training' are read.
        session: Tensorflow session.
        train_op: Op for computing gradients and updating parameters in
            model. Only fetched when mode is 'train'.
        shuffle: For shuffling data before epoch.
        mode: 'train' or 'val'.

    Returns:
        Metrics in python dictionary with the keys 'total_loss', 'x_loss',
        'i_loss', 'w_loss', 'clf_loss' and 'accuracy'. 'total_loss' and
        'accuracy' are averaged over the number of batches; the remaining
        losses are summed over the epoch and divided by the number of
        examples.

    Raises:
        ValueError: If mode is neither 'train' nor 'val'. Raised before
            any data is loaded.
    """
    # Resolve the mode first so an invalid value fails immediately, not
    # after the first batch has already been loaded from disk.
    if mode == 'train':
        fetches = [train_op] + model.log_var
        n_skip = 1  # the first fetched value belongs to train_op
    elif mode == 'val':
        fetches = model.val_log_var
        n_skip = 0
    else:
        raise ValueError("Argument \'mode\' %s doesn't exist!" % mode)

    # Hyperparameters
    batch_size = hyperparams['batch_size']
    dropout_rate = hyperparams['dropout_rate']
    is_training = hyperparams['is_training']

    # Data
    features = data['features']
    labels = data['labels']
    captions = data['captions']
    iconic_image_path = data['iconic_image_paths']
    n_classes = data['n_classes']

    n_examples = len(features)
    n_batches = int(np.ceil(n_examples / batch_size))

    if shuffle:
        # One permutation keeps features, image paths and labels aligned.
        # captions is handed to load_captions together with the batch
        # labels, so it is not permuted here.
        perm = np.random.permutation(n_examples)
        features = features[perm]
        iconic_image_path = iconic_image_path[perm]
        labels = labels[perm]

    total_loss = 0.
    x_loss = 0.
    i_loss = 0.
    w_loss = 0.
    clf_loss = 0.
    accuracy = 0.

    for i in range(n_batches):
        start = i * batch_size
        end = min(start + batch_size, n_examples)

        # Prepare batch and hyperparameters
        batch_labels = labels[start:end]
        feed_dict = {
            model.x: features[start:end],
            model.iconic_images: load_iconic_images(
                iconic_image_path[start:end]),
            model.captions: load_captions(captions, batch_labels),
            model.labels: onehot_encode(batch_labels, n_classes),
            model.dropout_rate: dropout_rate,
            model.is_training: is_training,
        }

        # Training and validation share the same metric layout once the
        # train_op slot has been dropped.
        step = session.run(fetches, feed_dict=feed_dict)[n_skip:]
        total_loss += step[0]
        x_loss += np.sum(step[1])
        i_loss += np.sum(step[2])
        w_loss += np.sum(step[3])
        clf_loss += np.sum(step[4])
        accuracy += np.sum(step[5])

    # Epoch finished, return results.
    # NOTE(review): the original comment states that clf_loss and accuracy
    # are zero if args.classifier_head is False; that depends on the
    # model's log variables and is not enforced here.
    results = {'total_loss': total_loss / n_batches,
               'x_loss': x_loss / n_examples,
               'i_loss': i_loss / n_examples,
               'w_loss': w_loss / n_examples,
               'clf_loss': clf_loss / n_examples,
               'accuracy': accuracy / n_batches}
    return results