def restore_config(FLAGS, experiment_id):
  """Restores the configuration of an existing experiment from its model directory."""
  try:
    data_root = os.environ["COSE_DATA_DIR"]
    log_dir = os.environ["COSE_LOG_DIR"]
    eval_dir = os.environ["COSE_EVAL_DIR"]
    gdrive_key = os.environ["GDRIVE_API_KEY"]
  except KeyError:
    if FLAGS.data_dir is None or FLAGS.eval_dir is None or FLAGS.experiment_dir is None:
      raise Exception("Either environment variables or FLAGS must be set.")
    data_root = FLAGS.data_dir
    log_dir = FLAGS.experiment_dir
    eval_dir = FLAGS.eval_dir
    gdrive_key = FLAGS.gdrive_api_key

  # Check if the experiment directory already exists.
  model_dir_query = glob.glob(os.path.join(log_dir, experiment_id + "*"))
  if not model_dir_query:
    raise ModelNotFoundError

  # Load experiment config.
  model_dir = model_dir_query[0]
  config = Configuration.from_json(os.path.join(model_dir, "config.json"))
  config.experiment.model_dir = model_dir
  config.experiment.eval_dir = os.path.join(eval_dir,
                                            os.path.basename(model_dir))
  if "predictive_model" not in config:
    raise NotPredictiveModelError

  __builtin__.print = Print(os.path.join(model_dir, "log.txt"))  # Overload print.
  print("Loading from " + config.experiment.model_dir)

  if not isinstance(config.data.data_tfrecord_fname, list):
    config.data.data_tfrecord_fname = [config.data.data_tfrecord_fname]
  data_path = [
      os.path.join(data_root, config.data.data_name, "{}", dp)
      for dp in config.data.data_tfrecord_fname
  ]
  config.data.train_data_path = [dp.format("training") for dp in data_path]
  config.data.valid_data_path = [dp.format("validation") for dp in data_path]
  config.data.test_data_path = [dp.format("test") for dp in data_path]
  config.data.meta_data_path = os.path.join(data_root, config.data.data_name,
                                            config.data.data_meta_fname)

  config.experiment.pretrained_emb_dir = None
  if config.experiment.get("pretrained_emb_id", None) is not None:
    config.experiment.pretrained_emb_dir = glob.glob(
        os.path.join(log_dir, config.experiment.pretrained_emb_id + "-*"))[0]

  config.gdrive.credential = gdrive_key
  return config
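
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): restore_config and
# get_config first read the COSE_* environment variables and only fall back
# to the corresponding flags. The helper below shows one way to point those
# variables at local directories before calling them; the helper name and the
# placeholder paths are assumptions for this example.
# ---------------------------------------------------------------------------
def _example_set_env(data_dir, log_dir, eval_dir, gdrive_key=""):
  """Illustrative helper: set the environment variables read by the config functions."""
  os.environ["COSE_DATA_DIR"] = data_dir      # e.g. "/path/to/cose/data"
  os.environ["COSE_LOG_DIR"] = log_dir        # e.g. "/path/to/experiments"
  os.environ["COSE_EVAL_DIR"] = eval_dir      # e.g. "/path/to/eval"
  os.environ["GDRIVE_API_KEY"] = gdrive_key   # Optional Google Sheets logging key.
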
def get_config(FLAGS, experiment_id=None):
  """Defines the default configuration."""
  experiment_id = FLAGS.experiment_id or experiment_id

  config = Configuration()

  config.experiment = ExperimentConfig(
      comment=FLAGS.comment,
      tag="",  # Set automatically.
      model_dir=None,  # Set automatically.
      eval_dir=None,  # Set automatically.
      id=experiment_id,
      max_epochs=None,
      max_steps=200000,
      log_frequency=100,
      eval_steps=500,
      checkpoint_frequency=500,
      grad_clip_norm=FLAGS.grad_clip_norm if FLAGS.grad_clip_value <= 0 else 0,
      grad_clip_value=FLAGS.grad_clip_value,
      pretrained_emb_id=FLAGS.pretrained_emb_id)

  config.experiment.learning_rate = AttrDict(
      name=FLAGS.learning_rate_type,
      initial_learning_rate=FLAGS.learning_rate,
  )
  if FLAGS.learning_rate_type == "transformer":
    config.experiment.learning_rate.d_model = FLAGS.transformer_dmodel
    config.experiment.learning_rate.warmup_steps = 4000

  config.data = DataConfig(
      data_dir=FLAGS.data_dir,
      data_name=FLAGS.data_name,
      data_tfrecord_fname=C.DATASET_MAP[FLAGS.data_name]["data_tfrecord_fname"],
      data_meta_fname=C.DATASET_MAP[FLAGS.data_name][FLAGS.metadata_type],
      pp_to_origin="position" in FLAGS.metadata_type,
      pp_relative_pos="velocity" in FLAGS.metadata_type,
      normalize=not FLAGS.skip_normalization,
      batch_size=FLAGS.batch_size,
      max_length_threshold=201,
      mask_pen=FLAGS.mask_encoder_pen,
      resampling_factor=FLAGS.resampling_factor,
      t_drop_ratio=FLAGS.t_drop_ratio,
      gt_targets=FLAGS.gt_targets,
      scale_factor=FLAGS.scale_factor,
      affine_prob=FLAGS.affine_prob,
      reverse_prob=FLAGS.reverse_prob,
      n_t_samples=FLAGS.n_t_samples,
      int_t_samples=FLAGS.int_t_samples,
      concat_t_inputs=FLAGS.concat_t_inputs,
      rdp_dataset=FLAGS.rdp_dataset,
      rdp_didi_pp=FLAGS.rdp_didi_pp,
      pos_noise_factor=FLAGS.pos_noise_factor,
  )

  config.gdrive = AttrDict(
      credential=None,  # Set automatically below.
      workbook=None,  # Set your workbook ID (see https://github.com/emreaksan/glogger).
      sheet=FLAGS.data_name,
  )

  # Embedding model.
  if FLAGS.encoder_model == "rnn":
    config.encoder = AttrDict(
        name="rnn",
        cell_units=FLAGS.encoder_rnn_units,
        cell_layers=FLAGS.encoder_rnn_layers,
        cell_type=FLAGS.encoder_cell_type,
        bidirectional_encoder=FLAGS.bidirectional_encoder,
        rec_dropout_rate=FLAGS.encoder_rdropout)
  elif FLAGS.encoder_model == "transformer":
    config.encoder = AttrDict(
        name="transformer",
        layers=FLAGS.transformer_layers,
        heads=FLAGS.transformer_heads,
        d_model=FLAGS.transformer_dmodel,
        hidden_units=FLAGS.transformer_hidden_units,
        dropout_rate=FLAGS.transformer_dropout,
        pos_encoding=config.data.max_length_threshold
        if FLAGS.transformer_pos_encoding else 0,
        scale=FLAGS.transformer_scale,
        autoregressive=not FLAGS.bidirectional_encoder)
  else:
    err_unknown_type(FLAGS.encoder_model)

  config.embedding = AttrDict(
      latent_units=FLAGS.latent_units,
      use_vae=FLAGS.use_vae,
  )

  if FLAGS.decoder_model == "rnn":
    config.decoder = AttrDict(
        name="rnn",
        cell_units=FLAGS.encoder_rnn_units,  # Using the same hyper-param with the encoder.
        cell_layers=FLAGS.encoder_rnn_layers,
        cell_type=FLAGS.encoder_cell_type,
        dropout_rate=FLAGS.decoder_dropout,
        dynamic_h0=FLAGS.decoder_dynamic_h0,
        repeat_vae_sample=FLAGS.repeat_vae_sample,
        autoregressive=FLAGS.decoder_autoregressive,
    )
    target_key_pen = "pen"
    target_key_stroke = "stroke"
  elif FLAGS.decoder_model == "t_emb":
    config.decoder = AttrDict(
        name="t_emb",
        layers=FLAGS.decoder_layers,
        hidden_units=FLAGS.decoder_hidden_units,
        activation=FLAGS.decoder_activation,
        dropout_rate=FLAGS.decoder_dropout,
        t_frequency_channels=FLAGS.t_frequency_channels,
        regularizer_weight=FLAGS.reg_dec_weight)
    target_key_pen = C.TARGET_T_PEN
    target_key_stroke = C.TARGET_T_STROKE
  else:
    err_unknown_type(FLAGS.decoder_model)

  # Predictive model.
  if FLAGS.predictive_model == "rnn":
    config.predictive_model = AttrDict(
        name="rnn",
        output_size=config.embedding.latent_units,
        cell_units=FLAGS.predictive_rnn_units,
        cell_layers=FLAGS.predictive_rnn_layers,
        cell_type=FLAGS.predictive_cell_type,
        activation=C.RELU,
        use_start_pos=FLAGS.use_start_pos,
        use_end_pos=FLAGS.use_end_pos,
        stop_predictive_grad=FLAGS.stop_predictive_grad,
        num_predictive_inputs=FLAGS.num_pred_inputs,
        pred_input_type=FLAGS.pred_input_type,
    )
  elif FLAGS.predictive_model == "transformer":
    config.predictive_model = AttrDict(
        name="transformer",
        output_size=config.embedding.latent_units,
        layers=FLAGS.p_transformer_layers,
        heads=FLAGS.p_transformer_heads,
        d_model=FLAGS.p_transformer_dmodel,
        latent_units=FLAGS.latent_units,
        hidden_units=FLAGS.p_transformer_hidden_units,
        dropout_rate=FLAGS.p_transformer_dropout,
        pos_encoding=FLAGS.p_transformer_pos_encoding,
        scale=FLAGS.p_transformer_scale,
        use_start_pos=FLAGS.use_start_pos,
        use_end_pos=FLAGS.use_end_pos,
        stop_predictive_grad=FLAGS.stop_predictive_grad,
        num_predictive_inputs=FLAGS.num_pred_inputs,
        pred_input_type=FLAGS.pred_input_type,
        pooling_layer=FLAGS.pooling_layer,
    )
  else:
    err_unknown_type(FLAGS.predictive_model)

  # Position model. Sharing flags with the predictive model.
if FLAGS.position_model == "rnn": config.position_model = AttrDict( name="rnn", output_size=2, cell_units=FLAGS.predictive_rnn_units, cell_layers=FLAGS.predictive_rnn_layers, cell_type=FLAGS.predictive_cell_type, activation=C.RELU, ) elif FLAGS.position_model == "transformer": config.position_model = AttrDict( name="transformer", output_size=2, layers=FLAGS.p_transformer_layers, heads=FLAGS.p_transformer_heads, d_model=FLAGS.p_transformer_dmodel, hidden_units=FLAGS.p_transformer_hidden_units, dropout_rate=FLAGS.p_transformer_dropout, pos_encoding=FLAGS.p_transformer_pos_encoding, scale=FLAGS.p_transformer_scale, ) elif FLAGS.position_model is None: config.position_model = None else: err_unknown_type(FLAGS.position_model) # Loss config.loss = AttrDict() stroke_ = LossConfig(loss_type=FLAGS.stroke_loss, num_components=20, target_key=target_key_stroke, out_key="stroke_logits", reduce_type=C.R_MEAN_STEP) pen_ = LossConfig(eval_only=FLAGS.disable_pen_loss, loss_type=C.NLL_CENT_BINARY, target_key=target_key_pen, out_key="pen_logits", reduce_type=C.R_MEAN_STEP) ink_loss = AttrDict(pen=pen_, stroke=stroke_, prefix="reconstruction") if FLAGS.use_vae: ink_loss.embedding_kld = LossConfig( loss_type=FLAGS.kld_type, # C.KLD_STANDARD or C.KLD_STANDARD_NORM target_key=None, out_key="embedding", reduce_type=C.R_MEAN_STEP) ink_loss.embedding_kld.weight = FLAGS.kld_weight if FLAGS.kld_increment > 0: ink_loss.embedding_kld.weight = dict(type="linear_decay", values=[ FLAGS.kld_start, FLAGS.kld_weight, FLAGS.kld_increment ]) if FLAGS.reg_emb_weight > 0: ink_loss.embedding_l2 = LossConfig(loss_type=C.SNORM_L2, target_key=None, out_key="embedding", reduce_type=C.R_MEAN_STEP, weight=FLAGS.reg_emb_weight) embedding_pred = LossConfig(loss_type=FLAGS.embedding_loss, num_components=FLAGS.embedding_gmm_components, target_key="target", out_key="prediction", reduce_type=C.R_MEAN_STEP) position_pred = LossConfig(loss_type=FLAGS.embedding_loss, num_components=FLAGS.embedding_gmm_components, target_key="target", out_key="prediction", reduce_type=C.R_MEAN_STEP) # embedding_pred.weight = FLAGS.kld_weight # embedding_pred.weight = dict( # type="linear_decay", # values=[FLAGS.kld_start, FLAGS.kld_weight, FLAGS.kld_increment]) config.loss = AttrDict( ink=ink_loss, predicted_embedding=AttrDict(predicted_embedding=embedding_pred), predicted_ink=copy.deepcopy(ink_loss)) if config.position_model is not None: config.loss.predicted_pos = AttrDict(predicted_pos=position_pred) # No kld loss on the predicted ink. if "embedding_kld" in config.loss.predicted_ink: del config.loss.predicted_ink["embedding_kld"] config.loss.apply_predicted_embedding = FLAGS.loss_predicted_embedding config.loss.apply_predicted_ink = FLAGS.loss_predicted_ink config.loss.apply_reconstructed_ink = FLAGS.loss_reconstructed_ink try: data_root = os.environ["COSE_DATA_DIR"] log_dir = os.environ["COSE_LOG_DIR"] eval_dir = os.environ["COSE_EVAL_DIR"] gdrive_key = os.environ["GDRIVE_API_KEY"] except KeyError: if FLAGS.data_dir is None or FLAGS.eval_dir is None or FLAGS.experiment_dir is None: raise Exception( "Either environment variables or FLAGs must be set.") data_root = FLAGS.data_dir log_dir = FLAGS.experiment_dir eval_dir = FLAGS.eval_dir gdrive_key = FLAGS.gdrive_api_key # Check if the experiment directory already exists. model_dir_query = glob.glob( os.path.join(log_dir, config.experiment.id + "*")) if model_dir_query: model_dir = model_dir_query[0] __builtin__.print = Print(os.path.join(model_dir, "log.txt")) # Overload print. # Load experiment config. 
    config = config.from_json(os.path.join(model_dir, "config.json"))
    config.experiment.model_dir = model_dir
    config.experiment.eval_dir = os.path.join(eval_dir,
                                              os.path.basename(model_dir))
    if "predictive_model" not in config:
      raise NotPredictiveModelError
    print("Loading from " + config.experiment.model_dir)
  else:
    config.experiment.tag = build_experiment_name(config)
    model_dir_name = config.experiment.id + "-" + config.experiment.tag
    config.experiment.model_dir = os.path.join(log_dir, model_dir_name)
    config.experiment.eval_dir = os.path.join(eval_dir, model_dir_name)
    os.mkdir(config.experiment.model_dir)  # Create experiment directory.
    __builtin__.print = Print(
        os.path.join(config.experiment.model_dir, "log.txt"))  # Overload print.
    print("Saving to " + config.experiment.model_dir)

  if not isinstance(config.data.data_tfrecord_fname, list):
    config.data.data_tfrecord_fname = [config.data.data_tfrecord_fname]
  data_path = [
      os.path.join(data_root, config.data.data_name, "{}", dp)
      for dp in config.data.data_tfrecord_fname
  ]
  config.data.train_data_path = [dp.format("training") for dp in data_path]
  config.data.valid_data_path = [dp.format("validation") for dp in data_path]
  config.data.test_data_path = [dp.format("test") for dp in data_path]
  config.data.meta_data_path = os.path.join(data_root, config.data.data_name,
                                            config.data.data_meta_fname)

  config.experiment.pretrained_emb_dir = None
  if config.experiment.get("pretrained_emb_id", None) is not None:
    config.experiment.pretrained_emb_dir = glob.glob(
        os.path.join(log_dir, config.experiment.pretrained_emb_id + "-*"))[0]

  config.gdrive.credential = gdrive_key
  if FLAGS.gdrive_api_key is None:
    config.gdrive = None

  config.dump(config.experiment.model_dir)
  return config
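
# ---------------------------------------------------------------------------
# Entry-point sketch (illustrative, not the repo's actual trainer): shows how
# a script might choose between restore_config and get_config, assuming the
# flags referenced above (experiment_id, data_dir, ...) are registered
# elsewhere. The helper name and the restore_existing argument are
# hypothetical and exist only for this example.
# ---------------------------------------------------------------------------
def _example_build_config(FLAGS, restore_existing=False):
  """Illustrative helper: restore an existing experiment config or create a new one."""
  if restore_existing and FLAGS.experiment_id:
    # Reuse the config.json stored alongside an existing experiment.
    return restore_config(FLAGS, FLAGS.experiment_id)
  # Otherwise build (and dump) a fresh configuration from the flags.
  return get_config(FLAGS)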