def main(_):
  """Trains a binary question-answering model on symbolic states.

  Loads states, compressed captions (questions), per-caption answers and the
  vocabulary; decompresses and encodes the captions; then trains an encoder,
  a recurrent decoder and a sigmoid projection with binary cross-entropy on
  the first 70% of the (state, caption) pairs and evaluates on the rest.

  When FLAGS.save_dir is set, a checkpoint is written whenever the measured
  test accuracy improves, and training resumes (weights included) from the
  latest checkpoint found there.
  """
  tf.enable_v2_behavior()

  ##############################################################################
  ######################### Data loading and processing ########################
  ##############################################################################
  print('Loading data')

  with gfile.GFile(transition_state_path, 'r') as f:
    states = np.load(f)
  states = np.float32(states)

  with gfile.GFile(transition_label_path, 'r') as f:
    captions = pickle.load(f)

  with gfile.GFile(answer_path, 'r') as f:
    answers = pickle.load(f)

  with gfile.GFile(vocab_path, 'r') as f:
    vocab_list = f.readlines()

  # Strip the trailing newline; `.decode` assumes readlines() yields bytes.
  vocab_list = [w[:-1].decode('utf-8') for w in vocab_list]
  vocab_list = ['eos', 'sos', 'nothing'] + vocab_list
  # NOTE(review): unconditionally replaces the last vocabulary entry with
  # 'to' -- presumably a workaround for the vocab file; confirm.
  vocab_list[-1] = 'to'

  v2i, _ = wv.create_look_up_table(vocab_list)
  encode_fn = wv.encode_text_with_lookup_table(v2i)

  # captions[0] is a lookup table whose values are the compressed codes used
  # in captions[1:]; invert it to recover the caption text of each state.
  caption_decoding_map = {v: k for k, v in captions[0].items()}
  captions = [[caption_decoding_map[c] for c in caption]
              for caption in captions[1:]]

  # Flatten to one entry per caption. answers[i] presumably lists the answer
  # of each caption of state i in the same order -- TODO confirm.
  encoded_captions = []
  new_answers = []
  for i, all_cp in enumerate(captions):
    encoded_captions.extend(np.array(encode_fn(cp)) for cp in all_cp)
    new_answers.extend(float(a) for a in answers[i])

  all_caption_n = len(encoded_captions)
  encoded_captions = np.array(encoded_captions)
  encoded_captions = pad_to_max_length(encoded_captions)
  answers = np.float32(new_answers)

  # obs_idx[k] is the state that flattened caption k belongs to.
  obs_idx, caption_idx = [], []
  curr_caption_idx = 0
  for i, _ in enumerate(states):
    for _ in captions[i]:
      obs_idx.append(i)
      caption_idx.append(curr_caption_idx)
      curr_caption_idx += 1
  assert curr_caption_idx == all_caption_n

  obs_idx = np.array(obs_idx)
  caption_idx = np.array(caption_idx)
  all_idx = np.arange(len(caption_idx))
  split = int(len(all_idx) * 0.7)
  train_idx = all_idx[:split]
  test_idx = all_idx[split:]

  print('Number of training examples: {}'.format(len(train_idx)))
  print('Number of test examples: {}\n'.format(len(test_idx)))

  ##############################################################################
  ############################# Training Setup #################################
  ##############################################################################
  batch_size = 128
  max_sequence_length = 21

  encoder_config = {'name': 'state', 'embedding_dim': 64}
  decoder_config = {
      'name': 'state',
      'word_embedding_dim': 64,
      'hidden_units': 512,
      'vocab_size': len(vocab_list),
  }

  encoder = get_answering_encoder(encoder_config)
  decoder = get_answering_decoder(decoder_config)
  projection_layer = tf.keras.layers.Dense(
      1, activation='sigmoid', name='answering_projection')

  optimizer = tf.keras.optimizers.Adam(1e-4)
  bce = tf.keras.losses.BinaryCrossentropy()

  @tf.function
  def compute_loss(obs, instruction, target):
    """Runs the model on a batch; returns (BCE loss, sigmoid predictions)."""
    instruction = tf.expand_dims(instruction, axis=-1)
    hidden = decoder.reset_state(batch_size=target.shape[0])
    features = encoder(obs)
    # Feed the question one token per step; only the final hidden state is
    # projected to the answer probability.
    for i in tf.range(max_sequence_length):
      _, hidden, _ = decoder(instruction[:, i], features, hidden)
    projection = tf.squeeze(projection_layer(hidden), axis=1)
    loss = bce(target, projection)
    return loss, projection

  @tf.function
  def train_step(obs, instruction, target):
    """Applies one Adam update to encoder, decoder and projection layer."""
    with tf.GradientTape() as tape:
      loss, _ = compute_loss(obs, instruction, target)
    trainable_variables = (
        encoder.trainable_variables + decoder.trainable_variables +
        projection_layer.trainable_variables)
    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))
    return loss

  def evaluate(num_batches):
    """Mean loss and accuracy over the first contiguous test batches.

    Uses at most `num_batches` full batches (fewer if the test split is too
    small). Returns (nan, 0.0) when not even one full batch is available,
    instead of raising IndexError / ZeroDivisionError.
    """
    num_batches = min(num_batches, len(test_idx) // batch_size)
    if num_batches == 0:
      return float('nan'), 0.0
    test_total_loss = 0
    accuracy = 0
    for batch in range(num_batches):
      idx = test_idx[np.arange(batch_size) + batch * batch_size]
      input_tensor = tf.convert_to_tensor(states[obs_idx[idx], :])
      instruction = tf.convert_to_tensor(encoded_captions[caption_idx[idx]])
      target = tf.convert_to_tensor(answers[caption_idx[idx]])
      t_loss, prediction = compute_loss(input_tensor, instruction, target)
      test_total_loss += t_loss
      accuracy += np.mean(
          np.float32(np.float32(prediction > 0.5) == target))
    return test_total_loss / num_batches, accuracy / num_batches

  ##############################################################################
  ############################# Training Loop ##################################
  ##############################################################################
  print('Start training...\n')
  start_epoch = 0
  if FLAGS.save_dir:
    checkpoint_path = FLAGS.save_dir
    ckpt = tf.train.Checkpoint(
        encoder=encoder,
        decoder=decoder,
        projection_layer=projection_layer,
        optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(
        ckpt, checkpoint_path, max_to_keep=5)
    if ckpt_manager.latest_checkpoint:
      start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
      # Restore the weights too. Previously only the epoch counter was
      # recovered, so "resuming" skipped epochs with untrained parameters.
      ckpt.restore(ckpt_manager.latest_checkpoint)

  epochs = 400
  step_per_epoch = int(all_caption_n / batch_size)
  previous_best, previous_best_accuracy = 100., 0.0

  for epoch in range(start_epoch, epochs):
    start = time.time()
    total_loss = 0
    for batch in range(step_per_epoch):
      batch_idx = np.random.choice(train_idx, size=batch_size)
      input_tensor = tf.convert_to_tensor(states[obs_idx[batch_idx], :])
      instruction = tf.convert_to_tensor(
          encoded_captions[caption_idx[batch_idx]])
      target = tf.convert_to_tensor(answers[caption_idx[batch_idx]])
      batch_loss = train_step(input_tensor, instruction, target)
      total_loss += batch_loss
      if batch % 1000 == 0:
        print('Epoch {} Batch {} Loss {:.4f}'.format(
            epoch, batch, batch_loss.numpy()))

    # Quick validation on (up to) 10 test batches; checkpoint on improvement.
    if epoch % 5 == 0 and FLAGS.save_dir:
      test_total_loss, accuracy = evaluate(10)
      if accuracy > previous_best_accuracy:
        previous_best_accuracy, previous_best = accuracy, test_total_loss
        ckpt_manager.save(checkpoint_number=epoch)

    print('\nEpoch {} | Loss {:.6f} | Val loss {:.6f} | Accuracy {:.3f}'.
          format(epoch + 1, total_loss / step_per_epoch, previous_best,
                 previous_best_accuracy))
    print('Time taken for 1 epoch {:.6f} sec\n'.format(time.time() - start))

    # Full pass over every complete test batch.
    if epoch % 10 == 0:
      test_total_loss, accuracy = evaluate(len(test_idx) // batch_size)
      if accuracy > previous_best_accuracy and FLAGS.save_dir:
        previous_best_accuracy, previous_best = accuracy, test_total_loss
        ckpt_manager.save(checkpoint_number=epoch)
      print('\n====================================================')
      print('Test Loss {:.6f} | Test Accuracy {:.3f}'.format(
          test_total_loss, accuracy))
      print('====================================================\n')
def main(_):
  """Trains a binary caption-answering model on synthetic image transitions.

  Loads real and synthetic transitions (divided by 255 when their maximum
  exceeds 1), their captions and the vocabulary. Every caption is wrapped in
  'sos ... eos', encoded and padded with pad_to_max_length(max_l=15).
  Captions containing the substring 'nothing' are never sampled as
  positives. Each batch pairs every sampled observation with its own caption
  (target 1.0) and with a randomly drawn caption (target 0.0). Training and
  evaluation both sample the synthetic split. When FLAGS.save_dir is set, a
  checkpoint is written whenever test accuracy improves and training resumes
  (weights included) from the latest checkpoint.
  """
  tf.enable_v2_behavior()

  ##############################################################################
  ######################### Data loading and processing ########################
  ##############################################################################
  print('Loading data')

  with gfile.GFile(transition_path, 'r') as f:
    transitions = np.load(f)
  if np.max(transitions) > 1.0:
    transitions = transitions / 255.0

  with gfile.GFile(synthetic_transition_path, 'r') as f:
    synthetic_transitions = np.load(f)
  if np.max(synthetic_transitions) > 1.0:
    synthetic_transitions = synthetic_transitions / 255.0

  with gfile.GFile(transition_label_path, 'r') as f:
    captions = pickle.load(f)

  with gfile.GFile(synthetic_transition_label_path, 'r') as f:
    synthetic_captions = pickle.load(f)

  with gfile.GFile(vocab_path, 'r') as f:
    vocab_list = f.readlines()

  # Strip the trailing newline; `.decode` assumes readlines() yields bytes.
  vocab_list = [w[:-1].decode('utf-8') for w in vocab_list]
  vocab_list = ['eos', 'sos'] + vocab_list

  v2i, _ = wv.create_look_up_table(vocab_list)
  encode_fn = wv.encode_text_with_lookup_table(v2i)

  def encode_captions(caption_groups):
    """Flattens per-observation caption lists and encodes 'sos <cp> eos'."""
    return [
        np.array(encode_fn('sos ' + cp + ' eos'))
        for all_cp in caption_groups
        for cp in all_cp
    ]

  encoded_captions = encode_captions(captions)
  synthetic_encoded_captions = encode_captions(synthetic_captions)

  all_caption_n = len(encoded_captions)
  all_synthetic_caption_n = len(synthetic_encoded_captions)
  encoded_captions = pad_to_max_length(np.array(encoded_captions), max_l=15)
  synthetic_encoded_captions = pad_to_max_length(
      np.array(synthetic_encoded_captions), max_l=15)

  def build_index(transitions_s, captions_s, expected_n):
    """Indexes captions by their position in the flattened caption list.

    Returns:
      obs_idx_s: obs_idx_s[c] is the observation described by flattened
        caption c (one entry per caption, negatives included).
      positive_s: flattened ids of the captions without 'nothing'.
    """
    obs_idx_s, positive_s = [], []
    curr_caption_idx = 0
    for i, _ in enumerate(transitions_s):
      for cp in captions_s[i]:
        obs_idx_s.append(i)
        if 'nothing' not in cp:
          positive_s.append(curr_caption_idx)
        curr_caption_idx += 1
    assert curr_caption_idx == expected_n
    return np.array(obs_idx_s), np.array(positive_s)

  def split_80_20(n):
    """Returns (train, test) index ranges over `n` positive captions."""
    all_idx_s = np.arange(n)
    cut = int(len(all_idx_s) * 0.8)
    return all_idx_s[:cut], all_idx_s[cut:]

  obs_idx, caption_idx = build_index(transitions, captions, all_caption_n)
  synthetic_obs_idx, synthetic_caption_idx = build_index(
      synthetic_transitions, synthetic_captions, all_synthetic_caption_n)

  train_idx, test_idx = split_80_20(len(caption_idx))
  print('Number of training examples: {}'.format(len(train_idx)))
  print('Number of test examples: {}\n'.format(len(test_idx)))

  synthetic_train_idx, synthetic_test_idx = split_80_20(
      len(synthetic_caption_idx))
  print('Number of synthetic training examples: {}'.format(
      len(synthetic_train_idx)))
  print('Number of synthetic test examples: {}\n'.format(
      len(synthetic_test_idx)))

  def sample_batch(data_type, batch_size, mode='train'):
    """Samples a balanced batch of (observation, caption, target) tensors.

    Args:
      data_type: 'synthetic' samples the synthetic data; anything else
        samples the real data.
      batch_size: number of positive pairs. The returned batch holds
        2 * batch_size rows, positives first and then negatives.
      mode: 'train' samples the training split, otherwise the test split.

    Returns:
      (input_tensor, caption_tensor, target_tensor). Targets are 1.0 for the
      first half and 0.0 for the second half.
    """
    if data_type == 'synthetic':
      transitions_s = synthetic_transitions
      encoded_captions_s = synthetic_encoded_captions
      obs_idx_s = synthetic_obs_idx
      caption_idx_s = synthetic_caption_idx
      train_idx_s = synthetic_train_idx
      test_idx_s = synthetic_test_idx
    else:
      transitions_s = transitions
      encoded_captions_s = encoded_captions
      obs_idx_s = obs_idx
      caption_idx_s = caption_idx
      train_idx_s = train_idx
      test_idx_s = test_idx

    pool = train_idx_s if mode == 'train' else test_idx_s
    batch_idx_s = np.random.choice(pool, size=batch_size)
    positive_idx = caption_idx_s[batch_idx_s]
    # obs_idx_s has one entry per caption (negatives included), whereas
    # batch_idx_s indexes only the positive captions. Look the observation up
    # through the flattened caption id so image and caption stay aligned.
    # Index 1 on axis 1 presumably selects one frame of the transition --
    # TODO confirm.
    observations = transitions_s[obs_idx_s[positive_idx], 1, :]
    input_tensor = tf.convert_to_tensor(
        np.concatenate([observations, observations]))

    # NOTE(review): negatives are drawn from the training split even when
    # mode != 'train'; kept as in the original.
    negative_idx = caption_idx_s[np.random.choice(train_idx_s,
                                                  size=batch_size)]
    caption_tensor = tf.convert_to_tensor(
        np.concatenate([
            encoded_captions_s[positive_idx], encoded_captions_s[negative_idx]
        ],
                       axis=0))
    target_tensor = tf.convert_to_tensor(
        np.float32(
            np.concatenate([np.ones(batch_size),
                            np.zeros(batch_size)], axis=0)))
    return input_tensor, caption_tensor, target_tensor

  ##############################################################################
  ############################# Training Setup #################################
  ##############################################################################
  batch_size = 64
  max_sequence_length = 15

  encoder_config = {'name': 'image', 'embedding_dim': 64}
  decoder_config = {
      'name': 'attention',
      'word_embedding_dim': 64,
      'hidden_units': 256,
      'vocab_size': len(vocab_list),
  }

  encoder = get_answering_encoder(encoder_config)
  decoder = get_answering_decoder(decoder_config)
  projection_layer = tf.keras.layers.Dense(
      1, activation='sigmoid', name='answering_projection')

  optimizer = tf.keras.optimizers.Adam(1e-4)
  bce = tf.keras.losses.BinaryCrossentropy()

  @tf.function
  def compute_loss(obs, instruction, target, training):
    """Runs the model on a batch; returns (BCE loss, sigmoid predictions)."""
    print('Build compute loss...')
    instruction = tf.expand_dims(instruction, axis=-1)
    hidden = decoder.reset_state(batch_size=target.shape[0])
    features = encoder(obs, training=training)
    # Feed the caption token by token; only the final hidden state is used.
    for i in tf.range(max_sequence_length):
      _, hidden, _ = decoder(
          instruction[:, i], features, hidden, training=training)
    projection = tf.squeeze(projection_layer(hidden), axis=1)
    loss = bce(target, projection)
    return loss, projection

  @tf.function
  def train_step(obs, instruction, target):
    """Applies one Adam update to encoder, decoder and projection layer."""
    print('Build train step...')
    with tf.GradientTape() as tape:
      loss, _ = compute_loss(obs, instruction, target, True)
    trainable_variables = (
        encoder.trainable_variables + decoder.trainable_variables +
        projection_layer.trainable_variables)
    print('num trainable: ', len(trainable_variables))
    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))
    return loss

  def evaluate(num_batches):
    """Mean loss and accuracy over `num_batches` sampled synthetic test batches."""
    test_total_loss = 0
    accuracy = 0
    for _ in range(num_batches):
      input_tensor, instruction, target = sample_batch(
          'synthetic', batch_size, 'test')
      t_loss, prediction = compute_loss(
          input_tensor, instruction, target, training=False)
      test_total_loss += t_loss
      accuracy += np.mean(
          np.float32(np.float32(prediction > 0.5) == target))
    return test_total_loss / num_batches, accuracy / num_batches

  ##############################################################################
  ############################# Training Loop ##################################
  ##############################################################################
  print('Start training...\n')
  start_epoch = 0
  if FLAGS.save_dir:
    checkpoint_path = FLAGS.save_dir
    ckpt = tf.train.Checkpoint(
        encoder=encoder,
        decoder=decoder,
        projection_layer=projection_layer,
        optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(
        ckpt, checkpoint_path, max_to_keep=5)
    if ckpt_manager.latest_checkpoint:
      start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
      # Restore the weights too. Previously only the epoch counter was
      # recovered, so "resuming" skipped epochs with untrained parameters.
      ckpt.restore(ckpt_manager.latest_checkpoint)

  epochs = 400
  step_per_epoch = int(all_caption_n / batch_size)
  previous_best, previous_best_accuracy = 100., 0.0

  for epoch in range(start_epoch, epochs):
    start = time.time()
    total_loss = 0
    for batch in range(step_per_epoch):
      input_tensor, instruction, target = sample_batch(
          'synthetic', batch_size, 'train')
      batch_loss = train_step(input_tensor, instruction, target)
      total_loss += batch_loss
      if batch % 1000 == 0:
        print('Epoch {} Batch {} Loss {:.4f}'.format(
            epoch, batch, batch_loss.numpy()))

    # Quick validation on 10 sampled batches; checkpoint on improvement.
    if epoch % 5 == 0 and FLAGS.save_dir:
      test_total_loss, accuracy = evaluate(10)
      if accuracy > previous_best_accuracy:
        previous_best_accuracy, previous_best = accuracy, test_total_loss
        ckpt_manager.save(checkpoint_number=epoch)

    print('\nEpoch {} | Loss {:.6f} | Val loss {:.6f} | Accuracy {:.3f}'.
          format(epoch + 1, total_loss / step_per_epoch, previous_best,
                 previous_best_accuracy))
    print('Time taken for 1 epoch {:.6f} sec\n'.format(time.time() - start))

    if epoch % 10 == 0:
      # NOTE(review): the batch count is derived from the real test split
      # while the batches are sampled from the synthetic one. At least one
      # batch is used so the averages are defined.
      test_total_loss, accuracy = evaluate(
          max(1, len(test_idx) // batch_size))
      if accuracy > previous_best_accuracy and FLAGS.save_dir:
        previous_best_accuracy, previous_best = accuracy, test_total_loss
        ckpt_manager.save(checkpoint_number=epoch)
      print('\n====================================================')
      print('Test Loss {:.6f} | Test Accuracy {:.3f}'.format(
          test_total_loss, accuracy))
      print('====================================================\n')
def main(_):
  """Builds the experiment configuration from flags and runs training.

  Collects the relevant flags into a Config, which FLAGS.experiment_confg may
  override. The output directory lives under FLAGS.save_dir (or a TF temp
  dir) and is named after the flags listed in FLAGS.varying. config.json is
  written there. The function then alternates `learner.learn` cycles with a
  moving average of the achieved-goal rate. It reloads the saved model after
  a sudden drop, saves on improvement at save intervals, and records video
  rollouts at video intervals.
  """
  if FLAGS.use_tf2:
    tf.enable_v2_behavior()

  config_content = {
      'action_type': FLAGS.action_type,
      'obs_type': FLAGS.obs_type,
      'reward_shape_val': FLAGS.reward_shape_val,
      'use_subset_instruction': FLAGS.use_subset_instruction,
      'frame_skip': FLAGS.frame_skip,
      'use_polar': FLAGS.use_polar,
      'suppress': FLAGS.suppress,
      'diverse_scene_content': FLAGS.diverse_scene_content,
      'buffer_size': FLAGS.buffer_size,
      'use_movement_bonus': FLAGS.use_movement_bonus,
      'reward_scale': FLAGS.reward_scale,
      'scenario_type': FLAGS.scenario_type,
      'img_resolution': FLAGS.img_resolution,
      'render_resolution': FLAGS.render_resolution,

      # agent
      'agent_type': FLAGS.agent_type,
      'masking_q': FLAGS.masking_q,
      'discount': FLAGS.discount,
      'instruction_repr': FLAGS.instruction_repr,
      'encoder_type': FLAGS.encoder_type,
      'learning_rate': FLAGS.learning_rate,
      'polyak_rate': FLAGS.polyak_rate,
      'trainable_encoder': FLAGS.trainable_encoder,
      'embedding_type': FLAGS.embedding_type,

      # learner
      'num_episode': FLAGS.num_episode,
      'optimization_steps': FLAGS.optimization_steps,
      'batchsize': FLAGS.batchsize,
      'sample_new_scene_prob': FLAGS.sample_new_scene_prob,
      'max_episode_length': FLAGS.max_episode_length,
      'record_atomic_instruction': FLAGS.record_atomic_instruction,
      'paraphrase': FLAGS.paraphrase,
      'relabeling': FLAGS.relabeling,
      'k_immediate': FLAGS.k_immediate,
      'future_k': FLAGS.future_k,
      'negate_unary': FLAGS.negate_unary,
      'min_epsilon': FLAGS.min_epsilon,
      'epsilon_decay': FLAGS.epsilon_decay,
      'collect_cycle': FLAGS.collect_cycle,
      'use_synonym_for_rollout': FLAGS.use_synonym_for_rollout,
      'reset_mode': FLAGS.reset_mode,
      'maxent_irl': FLAGS.maxent_irl,

      # relabeler
      'sampling_temperature': FLAGS.sampling_temperature,
      'generated_label_num': FLAGS.generated_label_num,
      'use_labeler_as_reward': FLAGS.use_labeler_as_reward,
      'use_oracle_instruction': FLAGS.use_oracle_instruction
  }

  if FLAGS.maxent_irl:
    # The batch must split evenly across the IRL parallel workers.
    assert FLAGS.batchsize % FLAGS.irl_parallel_n == 0
    for irl_key in ('irl_parallel_n', 'irl_sample_goal_n',
                    'relabel_proportion', 'entropy_alpha'):
      config_content[irl_key] = getattr(FLAGS, irl_key)

  cfg = Config(config_content)
  if FLAGS.experiment_confg:
    cfg.update(get_exp_config(FLAGS.experiment_confg))

  save_home = FLAGS.save_dir or tf.test.get_temp_dir()
  if not FLAGS.varying:
    exp_name = 'SingleExperiment'
  else:
    # e.g. 'exp-a=1-b=2-' for FLAGS.varying == 'a,b'.
    exp_name = 'exp-' + ''.join(
        str(flag_name) + '=' + str(FLAGS[flag_name].value) + '-'
        for flag_name in FLAGS.varying.split(','))
  save_dir = os.path.join(save_home, exp_name)

  # Best effort: a failure (e.g. directory already exists) is only printed.
  for directory in (save_home, save_dir):
    try:
      gfile.MkDir(directory)
    except gfile.Error as err:
      print(err)

  cfg.update(Config({'model_dir': save_dir}))

  print('############################################################')
  print(cfg)
  print('############################################################')

  env, learner, replay_buffer, agent, _ = experiment_setup(
      cfg, FLAGS.use_tf2, FLAGS.use_nn_relabeling)
  agent.init_networks()

  logger_cls = Logger2 if FLAGS.use_tf2 else Logger
  logger = logger_cls(save_dir)

  with gfile.GFile(os.path.join(save_dir, 'config.json'), mode='w+') as f:
    json.dump(cfg.as_dict(), f, sort_keys=True, indent=4)

  if FLAGS.save_model:
    if tf.train.latest_checkpoint(save_dir):
      print('Loading saved weights from {}'.format(save_dir))
      agent.load_model(save_dir)
    video_dir = os.path.join(save_dir, 'rollout_cycle_{}.mp4'.format('init'))
    print('Saving video to {}'.format(video_dir))
    learner.rollout(
        env,
        agent,
        video_dir,
        num_episode=FLAGS.rollout_episode,
        record_trajectory=FLAGS.record_trajectory)

  # A negative value means "not yet initialised".
  success_rate_ema = -1.0

  for epoch in range(FLAGS.num_epoch):
    for cycle in range(FLAGS.num_cycle):
      stats = learner.learn(env, agent, replay_buffer)
      achieved = stats['achieved_goal']
      if success_rate_ema < 0:
        success_rate_ema = achieved

      dropped = achieved < 0.1 * success_rate_ema
      late_enough = stats['global_step'] > 100000
      if FLAGS.save_model and dropped and late_enough:
        # Performance collapsed: fall back to the last saved model and skip
        # the bookkeeping for this cycle.
        print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        print('Step {}: Loading models due to sudden loss drop D:'.format(
            stats['global_step']))
        print('Dropped from {} to {}'.format(success_rate_ema, achieved))
        agent.load_model(save_dir)
        print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        continue

      success_rate_ema = 0.95 * success_rate_ema + 0.05 * achieved
      on_save_step = stats['global_step'] % FLAGS.save_interval == 0
      improved = achieved > success_rate_ema
      if FLAGS.save_model and on_save_step and improved:
        print('Saving model to {}'.format(save_dir))
        agent.save_model(save_dir)

      if FLAGS.save_model and stats['global_step'] % FLAGS.video_interval == 0:
        video_dir = os.path.join(save_dir, 'rollout_cycle_{}.mp4'.format(cycle))
        print('Saving video to {}'.format(video_dir))
        test_success_rate = learner.rollout(
            env,
            agent,
            video_dir,
            record_video=FLAGS.save_video,
            num_episode=FLAGS.rollout_episode,
            record_trajectory=FLAGS.record_trajectory)
        stats['Test Success Rate'] = test_success_rate
        print('Test Success Rate: {}'.format(test_success_rate))

      stats['ema success rate'] = success_rate_ema
      logger.log(epoch, cycle, stats)
def main(_):
  """Trains the captioning model on a mix of real and synthetic transitions.

  Loads real and synthetic transitions (divided by 255 when their maximum
  exceeds 1), their captions and the vocabulary. Every caption is wrapped in
  'sos ... eos', encoded and padded with pad_to_max_length(max_l=15). Each
  training batch holds 2 * batch_size examples: 60% real and 40% synthetic,
  passed through encoder.preprocess and trained with teacher forcing. When
  FLAGS.save_dir is set, a checkpoint is written whenever the mixed
  validation loss improves, and training resumes (weights included) from the
  latest checkpoint.
  """
  tf.enable_v2_behavior()

  ##############################################################################
  ######################### Data loading and processing ########################
  ##############################################################################
  print('Loading data')

  with gfile.GFile(_TRANSITION_PATH, 'r') as f:
    transitions = np.load(f)
  if np.max(transitions) > 1.0:
    transitions = transitions / 255.0

  with gfile.GFile(_SYNTHETIC_TRANSITION_PATH, 'r') as f:
    # This was assigned to a misspelled name ("synthetic_tran_sitions"),
    # which left `synthetic_transitions` unbound on the next line.
    synthetic_transitions = np.load(f)
  if np.max(synthetic_transitions) > 1.0:
    synthetic_transitions = synthetic_transitions / 255.0

  with gfile.GFile(transition_label_path, 'r') as f:
    captions = pickle.load(f)

  with gfile.GFile(_SYNTHETIC_TRANSITION_LABEL_PATH, 'r') as f:
    synthetic_captions = pickle.load(f)

  with gfile.GFile(vocab_path, 'r') as f:
    vocab_list = f.readlines()

  # Strip the trailing newline; `.decode` assumes readlines() yields bytes.
  vocab_list = [w[:-1].decode('utf-8') for w in vocab_list]
  vocab_list = ['eos', 'sos'] + vocab_list

  v2i, _ = wv.create_look_up_table(vocab_list)
  encode_fn = wv.encode_text_with_lookup_table(v2i)

  def encode_captions(caption_groups):
    """Flattens per-observation caption lists and encodes 'sos <cp> eos'."""
    return [
        np.array(encode_fn('sos ' + cp + ' eos'))
        for all_cp in caption_groups
        for cp in all_cp
    ]

  encoded_captions = encode_captions(captions)
  synthetic_encoded_captions = encode_captions(synthetic_captions)

  all_caption_n = len(encoded_captions)
  all_synthetic_caption_n = len(synthetic_encoded_captions)
  encoded_captions = pad_to_max_length(np.array(encoded_captions), max_l=15)
  synthetic_encoded_captions = pad_to_max_length(
      np.array(synthetic_encoded_captions), max_l=15)

  def build_index(transitions_s, captions_s, expected_n):
    """Returns (obs index per flattened caption, flattened caption ids)."""
    obs_idx_s = [
        i for i, _ in enumerate(transitions_s) for _ in captions_s[i]
    ]
    assert len(obs_idx_s) == expected_n
    return np.array(obs_idx_s), np.arange(len(obs_idx_s))

  obs_idx, caption_idx = build_index(transitions, captions, all_caption_n)
  synthetic_obs_idx, synthetic_caption_idx = build_index(
      synthetic_transitions, synthetic_captions, all_synthetic_caption_n)

  all_idx = np.arange(len(caption_idx))
  split = int(len(all_idx) * 0.8)
  train_idx = all_idx[:split]
  test_idx = all_idx[split:]
  print('Number of training examples: {}'.format(len(train_idx)))
  print('Number of test examples: {}\n'.format(len(test_idx)))

  synthetic_all_idx = np.arange(len(synthetic_caption_idx))
  synthetic_split = int(len(synthetic_all_idx) * 0.8)
  synthetic_train_idx = synthetic_all_idx[:synthetic_split]
  synthetic_test_idx = synthetic_all_idx[synthetic_split:]
  print('Number of synthetic training examples: {}'.format(
      len(synthetic_train_idx)))
  print('Number of synthetic test examples: {}\n'.format(
      len(synthetic_test_idx)))

  ##############################################################################
  ############################# Training Setup #################################
  ##############################################################################
  batch_size = 64

  encoder_config = {'name': 'image', 'embedding_dim': 32}
  decoder_config = {
      'name': 'attention',
      'word_embedding_dim': 64,
      'hidden_units': 256,
      'vocab_size': len(vocab_list),
  }

  encoder = get_captioning_encoder(encoder_config)
  decoder = get_captioning_decoder(decoder_config)

  optimizer = tf.keras.optimizers.Adam()
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction='none')

  def loss_function(real, pred, sos_symbol=1):
    """Mean per-token cross-entropy, zeroing positions equal to `sos_symbol`.

    NOTE(review): only the 'sos' id (1) is masked; whether padding should
    also be masked depends on the pad value used by pad_to_max_length.
    """
    mask = tf.math.logical_not(tf.math.equal(real, sos_symbol))
    loss_ = loss_object(real, pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    return tf.reduce_mean(loss_)

  @tf.function
  def train_step(input_tensor, target):
    """Traing on a batch of data."""
    loss = 0
    # Reset the hidden state for every batch because captions are unrelated
    # across images.
    hidden = decoder.reset_state(batch_size=target.shape[0])
    # Every caption starts with 'sos' (id 1).
    dec_input = tf.expand_dims([1] * target.shape[0], 1)
    with tf.GradientTape() as tape:
      features = encoder(input_tensor, training=True)
      for i in range(1, target.shape[1]):
        predictions, hidden, _ = decoder(
            dec_input, features, hidden, training=True)
        loss += loss_function(target[:, i], predictions)
        # Teacher forcing: feed the ground-truth token as next input.
        dec_input = tf.expand_dims(target[:, i], 1)
    total_loss = (loss / int(target.shape[1]))
    trainable_variables = (
        encoder.trainable_variables + decoder.trainable_variables)
    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))
    return loss, total_loss

  @tf.function
  def evaluate_batch(input_tensor, target):
    """Evaluate loss on a batch of data."""
    loss = 0
    hidden = decoder.reset_state(batch_size=target.shape[0])
    dec_input = tf.expand_dims([1] * target.shape[0], 1)
    features = encoder(input_tensor, training=False)
    for i in range(1, target.shape[1]):
      predictions, hidden, _ = decoder(
          dec_input, features, hidden, training=False)
      loss += loss_function(target[:, i], predictions)
      # Teacher forcing, as in training.
      dec_input = tf.expand_dims(target[:, i], 1)
    total_loss = (loss / int(target.shape[1]))
    return total_loss

  def eval_loss(inputs, targets):
    """Per-token loss on raw `inputs`, preprocessed the same way as training."""
    return evaluate_batch(encoder.preprocess(inputs), targets)

  ##############################################################################
  ############################# Training Loop ##################################
  ##############################################################################
  print('Start training...\n')
  start_epoch = 0
  if FLAGS.save_dir:
    checkpoint_path = FLAGS.save_dir
    ckpt = tf.train.Checkpoint(
        encoder=encoder, decoder=decoder, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(
        ckpt, checkpoint_path, max_to_keep=5)
    if ckpt_manager.latest_checkpoint:
      start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
      # Restore the weights too. Previously only the epoch counter was
      # recovered, so "resuming" skipped epochs with untrained parameters.
      ckpt.restore(ckpt_manager.latest_checkpoint)

  epochs = 400
  step_per_epoch = int(len(captions) / batch_size) * 10
  previous_best = 100.
  # Each training batch holds 2 * batch_size examples, 40% of them synthetic.
  mixing_ratio = 0.4
  syn_bs = int(batch_size * 2 * mixing_ratio)
  true_bs = int(batch_size * 2 * (1 - mixing_ratio))

  for epoch in range(start_epoch, epochs):
    start = time.time()
    total_loss = 0
    for batch in range(step_per_epoch):
      batch_idx = np.random.choice(train_idx, size=true_bs)
      synthetic_batch_idx = np.random.choice(synthetic_train_idx, size=syn_bs)
      input_tensor = np.concatenate([
          transitions[obs_idx[batch_idx], :],
          synthetic_transitions[synthetic_obs_idx[synthetic_batch_idx], :],
      ],
                                    axis=0)
      input_tensor = encoder.preprocess(input_tensor)
      target = np.concatenate([
          encoded_captions[caption_idx[batch_idx]],
          synthetic_encoded_captions[
              synthetic_caption_idx[synthetic_batch_idx]],
      ],
                              axis=0)
      batch_loss, t_loss = train_step(input_tensor, target)
      total_loss += t_loss
      if batch % 100 == 0:
        print('Epoch {} Batch {} Loss {:.4f}'.format(
            epoch + 1, batch,
            batch_loss.numpy() / int(target.shape[1])))

    # Validation: 3 real + 3 synthetic test batches, checkpoint on
    # improvement. Indices are clipped to each split's size (previously a
    # hard-coded 196 for the real split and no clipping for the synthetic
    # one).
    if epoch % 5 == 0 and FLAGS.save_dir:
      test_total_loss = 0
      for batch in range(3):
        batch_idx = np.clip(
            np.arange(true_bs) + batch * true_bs, 0, len(test_idx) - 1)
        idx = test_idx[batch_idx]
        test_total_loss += eval_loss(transitions[obs_idx[idx], :],
                                     encoded_captions[caption_idx[idx]])

        batch_idx = np.clip(
            np.arange(syn_bs) + batch * syn_bs, 0,
            len(synthetic_test_idx) - 1)
        idx = synthetic_test_idx[batch_idx]
        test_total_loss += eval_loss(
            synthetic_transitions[synthetic_obs_idx[idx], :],
            synthetic_encoded_captions[synthetic_caption_idx[idx]])
      test_total_loss /= 6.
      if test_total_loss < previous_best:
        previous_best = test_total_loss
        ckpt_manager.save(checkpoint_number=epoch)

    print('Epoch {} | Loss {:.6f} | Val loss {:.6f}'.format(
        epoch + 1, total_loss / step_per_epoch, previous_best))
    print('Time taken for 1 epoch {:.6f} sec\n'.format(time.time() - start))

    # Full pass over every complete batch of the real test split.
    num_test_batches = len(test_idx) // batch_size
    if epoch % 20 == 0 and num_test_batches:
      test_loss = 0
      for batch in range(num_test_batches):
        idx = test_idx[np.arange(batch_size) + batch * batch_size]
        test_loss += eval_loss(transitions[obs_idx[idx], :],
                               encoded_captions[caption_idx[idx]])
      print('====================================================')
      print('Test Loss {:.6f}'.format(test_loss / num_test_batches))
      print('====================================================\n')
def load_and_train():
  """Load data from file and return checkpoints from training.

  Reads the simulation CSV and shuffles its unique users with
  FLAGS.data_seed. The first half of the users is used for training and the
  rest for validation. Each split's 'a' (rec), 'm' (diversity, rating),
  'y' (ltr) and 't' (all ones) arrays are built, and a
  train_proxy.LogisticReg model is fitted with train_proxy.train.

  Returns:
    The checkpoints produced by train_proxy.train.
  """
  simulation_path = f'{FLAGS.data_path}/{FLAGS.simulation_dir}'
  with gfile.GFile(f'{simulation_path}/{FLAGS.data_file}', 'r') as f:
    df = pd.read_csv(f)

  # Split by user (not by row) so no user appears in both splits.
  rng = np.random.default_rng(FLAGS.data_seed)
  users = rng.permutation(np.unique(df['user']))
  n_users = users.shape[0]
  n_train_users = int(n_users / 2)
  users_train, users_val = users[:n_train_users], users[n_train_users:]
  assert users_val.shape[0] + users_train.shape[0] == n_users

  # The query strings reference the locals `users_train` / `users_val`.
  df_tr = df.query('user in @users_train').copy()
  df_val = df.query('user in @users_val').copy()

  def to_arrays(split_df):
    """Extracts the model inputs of one split as a dict of numpy arrays."""
    actions = split_df['rec'].to_numpy()
    return {
        'a': actions,
        'm': split_df[['diversity', 'rating']].to_numpy(),
        'y': split_df['ltr'].to_numpy(),
        't': np.ones_like(actions),
    }

  data_tr = to_arrays(df_tr)
  data_val = to_arrays(df_val)
  model = train_proxy.LogisticReg()

  init_params = train_proxy.initialize_params(model, mdim=2, seed=FLAGS.seed)

  loss_kwargs = dict(
      erm_weight=FLAGS.erm_weight,
      bias_lamb=FLAGS.bias_lamb,
      bias_norm=FLAGS.bias_norm)
  loss_tr = train_proxy.make_loss_func(model, data_tr, **loss_kwargs)
  loss_val = train_proxy.make_loss_func(model, data_val, **loss_kwargs)

  _, checkpoints = train_proxy.train(
      loss_tr,
      init_params,
      validation_loss=loss_val,
      lr=FLAGS.learning_rate,
      nsteps=FLAGS.nsteps,
      tol=FLAGS.tol,
      verbose=True,
      log=True)
  return checkpoints