# Example #1
# 0
def main():
    parser = argparse.ArgumentParser(
        description="LSTM VAE Agents: ST-GS Language Emergence.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--parent_folder",
                        type=str,
                        help="folder to save into.",
                        default="")
    parser.add_argument("--symbolic", action="store_true", default=False)
    parser.add_argument("--use_cuda", action="store_true", default=False)
    parser.add_argument("--dataset",
                        type=str,
                        choices=[
                            "Sort-of-CLEVR",
                            "tiny-Sort-of-CLEVR",
                            "XSort-of-CLEVR",
                            "tiny-XSort-of-CLEVR",
                            "dSprites",
                        ],
                        help="dataset to train on.",
                        default="dSprites")
    parser.add_argument("--arch",
                        type=str,
                        choices=[
                            "CNN",
                            "CNN3x3",
                            "BN+CNN",
                            "BN+CNN3x3",
                            "BN+BetaVAE3x3",
                            "BN+Coord2CNN3x3",
                            "BN+Coord4CNN3x3",
                        ],
                        help="model architecture to train",
                        default="BN+BetaVAE3x3")
    parser.add_argument(
        "--graphtype",
        type=str,
        choices=[
            "straight_through_gumbel_softmax", "reinforce",
            "baseline_reduced_reinforce", "normalized_reinforce",
            "baseline_reduced_normalized_reinforce", "max_entr_reinforce",
            "baseline_reduced_normalized_max_entr_reinforce",
            "argmax_reinforce", "obverter"
        ],
        help=
        "type of graph to use during training of the speaker and listener.",
        default="straight_through_gumbel_softmax")
    parser.add_argument("--max_sentence_length", type=int, default=20)
    parser.add_argument("--vocab_size", type=int, default=100)
    parser.add_argument("--optimizer_type",
                        type=str,
                        choices=["adam", "sgd"],
                        default="adam")
    parser.add_argument("--agent_loss_type",
                        type=str,
                        choices=[
                            "Hinge",
                            "NLL",
                            "CE",
                            "BCE",
                        ],
                        default="BCE")
    parser.add_argument("--agent_type",
                        type=str,
                        choices=[
                            "Baseline",
                        ],
                        default="Baseline")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epoch", type=int, default=10000)
    parser.add_argument("--metric_epoch_period", type=int, default=20)
    parser.add_argument("--dataloader_num_worker", type=int, default=4)
    parser.add_argument("--metric_fast", action="store_true", default=False)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--mini_batch_size", type=int, default=128)
    parser.add_argument("--dropout_prob", type=float, default=0.0)
    parser.add_argument("--emb_dropout_prob", type=float, default=0.8)
    parser.add_argument("--nbr_experience_repetition", type=int, default=1)
    parser.add_argument("--nbr_train_dataset_repetition", type=int, default=1)
    parser.add_argument("--nbr_test_dataset_repetition", type=int, default=1)
    parser.add_argument("--nbr_test_distractors", type=int, default=0)
    parser.add_argument("--nbr_train_distractors", type=int, default=0)
    parser.add_argument("--resizeDim",
                        default=32,
                        type=int,
                        help="input image resize")
    parser.add_argument("--shared_architecture",
                        action="store_true",
                        default=False)
    parser.add_argument("--with_baseline", action="store_true", default=False)
    parser.add_argument("--homoscedastic_multitasks_loss",
                        action="store_true",
                        default=False)
    parser.add_argument("--with_speaker_entropy_regularization",
                        action="store_true",
                        default=False)
    parser.add_argument("--use_curriculum_nbr_distractors",
                        action="store_true",
                        default=False)
    parser.add_argument("--use_feat_converter",
                        action="store_true",
                        default=False)
    parser.add_argument("--descriptive", action="store_true", default=False)
    parser.add_argument("--egocentric", action="store_true", default=False)
    parser.add_argument("--egocentric_tr_degrees", type=int, default=25)
    parser.add_argument("--egocentric_tr_xy", type=float, default=0.0625)
    parser.add_argument("--distractor_sampling",
                        type=str,
                        choices=[
                            "uniform",
                            "similarity-0.98",
                            "similarity-0.90",
                            "similarity-0.75",
                        ],
                        default="uniform")
    # Obverter Hyperparameters:
    parser.add_argument("--use_sentences_one_hot_vectors",
                        action="store_true",
                        default=False)
    parser.add_argument("--differentiable", action="store_true", default=False)
    parser.add_argument("--obverter_threshold_to_stop_message_generation",
                        type=float,
                        default=0.95)
    parser.add_argument("--obverter_nbr_games_per_round", type=int, default=4)
    # Cultural Bottleneck:
    parser.add_argument("--iterated_learning_scheme",
                        action="store_true",
                        default=False)
    parser.add_argument("--iterated_learning_period", type=int, default=4)
    parser.add_argument("--iterated_learning_rehearse_MDL",
                        action="store_true",
                        default=False)
    parser.add_argument("--iterated_learning_rehearse_MDL_factor",
                        type=float,
                        default=1.0)

    # Dataset Hyperparameters:
    parser.add_argument(
        "--train_test_split_strategy",
        type=str,
        choices=[
            "combinatorial2-Y-2-8-X-2-8-Orientation-40-N-Scale-6-N-Shape-3-N",  # Exp : DoRGsFurtherDise interweaved split simple XY normal             
            "combinatorial2-Y-2-S8-X-2-S8-Orientation-40-N-Scale-4-N-Shape-1-N",
            "combinatorial2-Y-32-N-X-32-N-Orientation-5-S4-Scale-1-S3-Shape-3-N",  #Sparse 2 Attributes: Orient.+Scale 64 imgs, 48 train, 16 test
            "combinatorial2-Y-2-S8-X-2-S8-Orientation-40-N-Scale-6-N-Shape-3-N",  # 4x Denser 2 Attributes: 256 imgs, 192 train, 64 test,

            # Heart shape: interpolation:
            "combinatorial2-Y-4-2-X-4-2-Orientation-40-N-Scale-6-N-Shape-3-N",  #Sparse 2 Attributes: X+Y 64 imgs, 48 train, 16 test
            "combinatorial2-Y-2-2-X-2-2-Orientation-40-N-Scale-6-N-Shape-3-N",  #Dense 2 Attributes: X+Y 256 imgs, 192 train, 64 test
            "combinatorial2-Y-8-2-X-8-2-Orientation-10-2-Scale-1-2-Shape-3-N",  #COMB2:Sparser 4 Attributes: 264 test / 120 train
            "combinatorial2-Y-4-2-X-4-2-Orientation-5-2-Scale-1-2-Shape-3-N",  #COMB2:Sparse 4 Attributes: 2112 test / 960 train
            "combinatorial2-Y-2-2-X-2-2-Orientation-2-2-Scale-1-2-Shape-3-N",  #COMB2:Dense 4 Attributes: ? test / ? train
            "combinatorial2-Y-4-2-X-4-2-Orientation-5-2-Scale-6-N-Shape-3-N",  #COMB2 Sparse: 3 Attributes: XYOrientation 256 test / 256 train
            # Heart shape: Extrapolation:
            "combinatorial2-Y-4-S4-X-4-S4-Orientation-40-N-Scale-6-N-Shape-3-N",  #Sparse 2 Attributes: X+Y 64 imgs, 48 train, 16 test
            "combinatorial2-Y-8-S2-X-8-S2-Orientation-10-S2-Scale-1-S3-Shape-3-N",  #COMB2:Sparser 4 Attributes: 264 test / 120 train
            "combinatorial2-Y-4-S4-X-4-S4-Orientation-5-S4-Scale-1-S3-Shape-3-N",  #COMB2:Sparse 4 Attributes: 2112 test / 960 train
            "combinatorial2-Y-2-S8-X-2-S8-Orientation-2-S10-Scale-1-S3-Shape-3-N",  #COMB2:Dense 4 Attributes: ? test / ? train
            "combinatorial2-Y-4-S4-X-4-S4-Orientation-5-S4-Scale-6-N-Shape-3-N",  #COMB2 Sparse: 3 Attributes: XYOrientation 256 test / 256 train

            # Ovale shape:
            "combinatorial2-Y-1-S16-X-1-S16-Orientation-40-N-Scale-6-N-Shape-2-N",  # Denser 2 Attributes X+Y X 16/ Y 16/ --> 256 test / 768 train 
            "combinatorial2-Y-8-S2-X-8-S2-Orientation-10-S2-Scale-1-S3-Shape-2-N",  #COMB2:Sparser 4 Attributes: 264 test / 120 train
            "combinatorial2-Y-4-S4-X-4-S4-Orientation-5-S4-Scale-1-S3-Shape-2-N",  #COMB2:Sparse 4 Attributes: 2112 test / 960 train
            "combinatorial2-Y-2-S8-X-2-S8-Orientation-2-S10-Scale-1-S3-Shape-2-N",  #COMB2:Dense 4 Attributes: ? test / ? train

            #3 Attributes: denser 2 attributes(X+Y) with the sample size of Dense 4 attributes:
            "combinatorial2-Y-1-S16-X-1-S16-Orientation-2-S10-Scale-6-N-Shape-2-N",
            "combinatorial4-Y-4-S4-X-4-S4-Orientation-5-S4-Scale-1-S3-Shape-3-N",  #Sparse 4 Attributes: 192 test / 1344 train
        ],
        help="train/test split strategy",
        # INTER:
        #default="combinatorial2-Y-4-2-X-4-2-Orientation-40-N-Scale-6-N-Shape-3-N")
        # EXTRA:
        #default="combinatorial2-Y-4-S4-X-4-S4-Orientation-40-N-Scale-6-N-Shape-3-N")
        # EXTRA-3:
        #default="combinatorial2-Y-4-S4-X-4-S4-Orientation-5-S4-Scale-6-N-Shape-3-N")
        # INTER-3:
        default="combinatorial2-Y-4-2-X-4-2-Orientation-5-2-Scale-6-N-Shape-3-N"
    )
    parser.add_argument(
        "--fast",
        action="store_true",
        default=False,
        help=
        "Disable the deterministic CuDNN. It is likely to make the computation faster."
    )

    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------
    # VAE Hyperparameters:
    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------
    parser.add_argument("--vae_detached_featout",
                        action="store_true",
                        default=False)

    parser.add_argument("--vae_lambda", type=float, default=1.0)
    parser.add_argument("--vae_use_mu_value",
                        action="store_true",
                        default=False)

    parser.add_argument("--vae_nbr_latent_dim", type=int, default=32)
    parser.add_argument("--vae_decoder_nbr_layer", type=int, default=3)
    parser.add_argument("--vae_decoder_conv_dim", type=int, default=32)

    parser.add_argument("--vae_gaussian", action="store_true", default=False)
    parser.add_argument("--vae_gaussian_sigma", type=float, default=0.25)

    parser.add_argument("--vae_beta", type=float, default=1.0)
    parser.add_argument("--vae_factor_gamma", type=float, default=0.0)

    parser.add_argument("--vae_constrained_encoding",
                        action="store_true",
                        default=False)
    parser.add_argument("--vae_max_capacity", type=float, default=1e3)
    parser.add_argument("--vae_nbr_epoch_till_max_capacity",
                        type=int,
                        default=10)

    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------

    args = parser.parse_args()
    print(args)

    gaussian = args.vae_gaussian
    vae_observation_sigma = args.vae_gaussian_sigma

    vae_beta = args.vae_beta
    factor_vae_gamma = args.vae_factor_gamma

    vae_constrainedEncoding = args.vae_constrained_encoding
    maxCap = args.vae_max_capacity  #1e2
    nbrepochtillmaxcap = args.vae_nbr_epoch_till_max_capacity

    monet_gamma = 5e-1

    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------

    seed = args.seed

    # Following: https://pytorch.org/docs/stable/notes/randomness.html
    torch.manual_seed(seed)
    if hasattr(torch.backends, "cudnn") and not (args.fast):
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    np.random.seed(seed)
    random.seed(seed)
    # # Hyperparameters:

    nbr_epoch = args.epoch

    cnn_feature_size = -1  #600 #128 #256 #
    # Except for VAEs...!

    stimulus_resize_dim = args.resizeDim  #64 #28

    normalize_rgb_values = False

    rgb_scaler = 1.0  #255.0
    from ReferentialGym.datasets.utils import ResizeNormalize
    transform = ResizeNormalize(size=stimulus_resize_dim,
                                normalize_rgb_values=normalize_rgb_values,
                                rgb_scaler=rgb_scaler)

    from ReferentialGym.datasets.utils import AddEgocentricInvariance
    ego_inv_transform = AddEgocentricInvariance()

    transform_degrees = args.egocentric_tr_degrees
    transform_translate = (args.egocentric_tr_xy, args.egocentric_tr_xy)

    rg_config = {
        "observability": "partial",
        "max_sentence_length": args.max_sentence_length,
        "nbr_communication_round": 1,
        "nbr_distractors": {
            "train": args.nbr_train_distractors,
            "test": args.nbr_test_distractors
        },
        "distractor_sampling": args.distractor_sampling,
        # Default: use "similarity-0.5"
        # otherwise the emerging language
        # will have very high ambiguity...
        # Speakers find the strategy of uttering
        # a word that is relevant to the class/label
        # of the target, seemingly.
        "descriptive": args.descriptive,
        "descriptive_target_ratio":
        1 - (1 / (args.nbr_train_distractors + 2)),  #0.97, 
        # Default: 1-(1/(nbr_distractors+2)),
        # otherwise the agent find the local minimum
        # where it only predicts "no-target"...
        "object_centric": False,
        "nbr_stimulus": 1,
        "graphtype": args.graphtype,
        "tau0": 0.2,
        "gumbel_softmax_eps": 1e-6,
        "vocab_size": args.vocab_size,
        "symbol_embedding_size": 256,  #64
        "agent_architecture": args.
        arch,  #"CoordResNet18AvgPooled-2", #"BetaVAE", #"ParallelMONet", #"BetaVAE", #"CNN[-MHDPA]"/"[pretrained-]ResNet18[-MHDPA]-2"
        "agent_learning":
        "learning",  #"transfer_learning" : CNN"s outputs are detached from the graph...
        "agent_loss_type": args.agent_loss_type,  #"NLL"
        "cultural_pressure_it_period": None,
        "cultural_speaker_substrate_size": 1,
        "cultural_listener_substrate_size": 1,
        "cultural_reset_strategy":
        "oldestL",  # "uniformSL" #"meta-oldestL-SGD"
        "cultural_reset_meta_learning_rate": 1e-3,

        # Obverter"s Cultural Bottleneck:
        "iterated_learning_scheme": args.iterated_learning_scheme,
        "iterated_learning_period": args.iterated_learning_period,
        "iterated_learning_rehearse_MDL": args.iterated_learning_rehearse_MDL,
        "iterated_learning_rehearse_MDL_factor":
        args.iterated_learning_rehearse_MDL_factor,
        "obverter_stop_threshold": args.
        obverter_threshold_to_stop_message_generation,  #0.0 if not in use.
        "obverter_nbr_games_per_round": args.obverter_nbr_games_per_round,
        "obverter_least_effort_loss": False,
        "obverter_least_effort_loss_weights": [1.0 for x in range(0, 10)],
        "batch_size": args.batch_size,
        "dataloader_num_worker": args.dataloader_num_worker,
        "stimulus_depth_dim": 1 if "dSprites" in args.dataset else 3,
        "stimulus_resize_dim": stimulus_resize_dim,
        "learning_rate": args.lr,  #1e-3,
        "adam_eps": 1e-8,
        "dropout_prob": args.dropout_prob,
        "embedding_dropout_prob": args.emb_dropout_prob,
        "with_gradient_clip": False,
        "gradient_clip": 1e0,
        "use_homoscedastic_multitasks_loss":
        args.homoscedastic_multitasks_loss,
        "use_feat_converter": args.use_feat_converter,
        "use_curriculum_nbr_distractors": args.use_curriculum_nbr_distractors,
        "curriculum_distractors_window_size": 25,  #100,
        "unsupervised_segmentation_factor": None,  #1e5
        "nbr_experience_repetition": args.nbr_experience_repetition,
        "with_utterance_penalization": False,
        "with_utterance_promotion": False,
        "utterance_oov_prob":
        0.5,  # Expected penalty of observing out-of-vocabulary words. 
        # The greater this value, the greater the loss/cost.
        "utterance_factor": 1e-2,
        "with_speaker_entropy_regularization":
        args.with_speaker_entropy_regularization,
        "with_listener_entropy_regularization": False,
        "entropy_regularization_factor": -1e3,
        "with_mdl_principle": False,
        "mdl_principle_factor": 5e-2,
        "with_weight_maxl1_loss": False,
        "use_cuda": args.use_cuda,
        "train_transform": transform,
        "test_transform": transform,
    }

    if args.egocentric:
        rg_config["train_transform"] = T.Compose([
            ego_inv_transform,
            T.RandomAffine(degrees=transform_degrees,
                           translate=transform_translate,
                           scale=None,
                           shear=None,
                           resample=False,
                           fillcolor=0), transform
        ])
        rg_config["test_transform"] = T.Compose([
            ego_inv_transform,
            T.RandomAffine(degrees=transform_degrees,
                           translate=transform_translate,
                           scale=None,
                           shear=None,
                           resample=False,
                           fillcolor=0), transform
        ])

    ## Train set:
    train_split_strategy = args.train_test_split_strategy
    test_split_strategy = train_split_strategy

    ## Agent Configuration:
    agent_config = copy.deepcopy(rg_config)
    agent_config["use_cuda"] = rg_config["use_cuda"]
    agent_config["homoscedastic_multitasks_loss"] = rg_config[
        "use_homoscedastic_multitasks_loss"]
    agent_config["use_feat_converter"] = rg_config["use_feat_converter"]
    agent_config["max_sentence_length"] = rg_config["max_sentence_length"]
    agent_config["nbr_distractors"] = rg_config["nbr_distractors"][
        "train"] if rg_config["observability"] == "full" else 0
    agent_config["nbr_stimulus"] = rg_config["nbr_stimulus"]
    agent_config["nbr_communication_round"] = rg_config[
        "nbr_communication_round"]
    agent_config["descriptive"] = rg_config["descriptive"]
    agent_config["gumbel_softmax_eps"] = rg_config["gumbel_softmax_eps"]
    agent_config["agent_learning"] = rg_config["agent_learning"]

    # Obverter:
    agent_config[
        "use_obverter_threshold_to_stop_message_generation"] = args.obverter_threshold_to_stop_message_generation

    agent_config["symbol_embedding_size"] = rg_config["symbol_embedding_size"]

    # Recurrent Convolutional Architecture:
    agent_config["architecture"] = rg_config["agent_architecture"]
    agent_config["decoder_architecture"] = "DCNN"
    if args.symbolic:
        agent_config["decoder_architecture"] = "BN+MLP"

    agent_config["dropout_prob"] = rg_config["dropout_prob"]
    agent_config["embedding_dropout_prob"] = rg_config[
        "embedding_dropout_prob"]

    if "BetaVAE" in agent_config["architecture"]:
        agent_config['VAE_lambda'] = args.vae_lambda
        agent_config['vae_beta'] = args.vae_beta
        agent_config['factor_vae_gamma'] = args.vae_factor_gamma
        agent_config['vae_constrainedEncoding'] = args.vae_constrained_encoding
        agent_config['vae_use_gaussian_observation_model'] = args.vae_gaussian
        agent_config['vae_observation_sigma'] = args.vae_gaussian_sigma
        agent_config['vae_max_capacity'] = args.vae_max_capacity  #1e2
        agent_config[
            'vae_nbr_epoch_till_max_capacity'] = args.vae_nbr_epoch_till_max_capacity

        agent_config['vae_decoder_conv_dim'] = 32
        agent_config['vae_decoder_nbr_layer'] = 3
        agent_config['vae_nbr_latent_dim'] = args.vae_nbr_latent_dim
        agent_config['vae_detached_featout'] = args.vae_detached_featout
        agent_config['vae_use_mu_value'] = args.vae_use_mu_value

        if "BN" in args.arch:
            agent_config["cnn_encoder_channels"] = [
                "BN32", "BN32", "BN64", "BN64"
            ]
        else:
            agent_config["cnn_encoder_channels"] = [32, 32, 64, 64]

        if "3x3" in agent_config["architecture"]:
            agent_config["cnn_encoder_kernels"] = [3, 3, 3, 3]
        elif "7x4x4x3" in agent_config["architecture"]:
            agent_config["cnn_encoder_kernels"] = [7, 4, 4, 3]
        else:
            agent_config["cnn_encoder_kernels"] = [4, 4, 4, 4]
        agent_config["cnn_encoder_strides"] = [2, 2, 2, 2]
        agent_config["cnn_encoder_paddings"] = [1, 1, 1, 1]
        agent_config["cnn_encoder_fc_hidden_units"] = []  #[128,]
        # the last FC layer is provided by the cnn_encoder_feature_dim parameter below...

        # For a fair comparison between CNN an VAEs:
        agent_config["cnn_encoder_feature_dim"] = args.vae_nbr_latent_dim
        #agent_config["cnn_encoder_feature_dim"] = cnn_feature_size
        # N.B.: if cnn_encoder_fc_hidden_units is [],
        # then this last parameter does not matter.
        # The cnn encoder is not topped by a FC network.

        agent_config["cnn_encoder_mini_batch_size"] = args.mini_batch_size
        #agent_config["feat_converter_output_size"] = cnn_feature_size
        agent_config["feat_converter_output_size"] = 256

        if "MHDPA" in agent_config["architecture"]:
            agent_config["mhdpa_nbr_head"] = 4
            agent_config["mhdpa_nbr_rec_update"] = 1
            agent_config["mhdpa_nbr_mlp_unit"] = 256
            agent_config["mhdpa_interaction_dim"] = 128

        agent_config["temporal_encoder_nbr_hidden_units"] = 0
        agent_config["temporal_encoder_nbr_rnn_layers"] = 0
        agent_config["temporal_encoder_mini_batch_size"] = args.mini_batch_size
        agent_config["symbol_processing_nbr_hidden_units"] = agent_config[
            "temporal_encoder_nbr_hidden_units"]
        agent_config["symbol_processing_nbr_rnn_layers"] = 1

        ## Decoder:
        ### CNN:
        if "BN" in agent_config["decoder_architecture"]:
            agent_config["cnn_decoder_channels"] = [
                "BN64", "BN64", "BN32", "BN32"
            ]
        else:
            agent_config["cnn_decoder_channels"] = [64, 64, 32, 32]

        if "3x3" in agent_config["decoder_architecture"]:
            agent_config["cnn_decoder_kernels"] = [3, 3, 3, 3]
        elif "3x4x4x7" in agent_config["decoder_architecture"]:
            agent_config["cnn_decoder_kernels"] = [3, 4, 4, 7]
        else:
            agent_config["cnn_decoder_kernels"] = [4, 4, 4, 4]
        agent_config["cnn_decoder_strides"] = [2, 2, 2, 2]
        agent_config["cnn_decoder_paddings"] = [1, 1, 1, 1]

        ### MLP:
        if "BN" in agent_config["decoder_architecture"]:
            agent_config['mlp_decoder_fc_hidden_units'] = ["BN256", "BN256"]
        else:
            agent_config['mlp_decoder_fc_hidden_units'] = [256, 256]
        agent_config['mlp_decoder_fc_hidden_units'].append(40 * 6)

    else:
        raise NotImplementedError

    save_path = "./"
    if args.parent_folder != '':
        save_path += args.parent_folder + '/'
    save_path += f"{args.dataset}+DualLabeled/"
    if args.symbolic:
        save_path += f"Symbolic/"
    if args.egocentric:
        save_path += f"Egocentric-Rot{args.egocentric_tr_degrees}-XY{args.egocentric_tr_xy}/"
    save_path += f"{nbr_epoch}Ep_Emb{rg_config['symbol_embedding_size']}_CNN{cnn_feature_size}to{args.vae_nbr_latent_dim}"
    if args.shared_architecture:
        save_path += "/shared_architecture"
    save_path += f"Dropout{rg_config['dropout_prob']}_DPEmb{rg_config['embedding_dropout_prob']}"
    save_path += f"_BN_{rg_config['agent_learning']}/"
    save_path += f"{rg_config['agent_loss_type']}"

    if 'dSprites' in args.dataset:
        train_test_strategy = f"-{test_split_strategy}"
        if test_split_strategy != train_split_strategy:
            train_test_strategy = f"/train_{train_split_strategy}/test_{test_split_strategy}"
        save_path += f"/dSprites{train_test_strategy}"

    save_path += f"/OBS{rg_config['stimulus_resize_dim']}X{rg_config['stimulus_depth_dim']}C-Rep{rg_config['nbr_experience_repetition']}"

    if rg_config['use_curriculum_nbr_distractors']:
        save_path += f"+W{rg_config['curriculum_distractors_window_size']}Curr"
    if rg_config['with_utterance_penalization']:
        save_path += "+Tau-10-OOV{}PenProb{}".format(
            rg_config['utterance_factor'], rg_config['utterance_oov_prob'])
    if rg_config['with_utterance_promotion']:
        save_path += "+Tau-10-OOV{}ProProb{}".format(
            rg_config['utterance_factor'], rg_config['utterance_oov_prob'])

    if rg_config['with_gradient_clip']:
        save_path += '+ClipGrad{}'.format(rg_config['gradient_clip'])

    if rg_config['with_speaker_entropy_regularization']:
        save_path += 'SPEntrReg{}'.format(
            rg_config['entropy_regularization_factor'])
    if rg_config['with_listener_entropy_regularization']:
        save_path += 'LSEntrReg{}'.format(
            rg_config['entropy_regularization_factor'])

    if rg_config['iterated_learning_scheme']:
        save_path += f"-ILM{rg_config['iterated_learning_period']}{'+RehearseMDL{}'.format(rg_config['iterated_learning_rehearse_MDL_factor']) if rg_config['iterated_learning_rehearse_MDL'] else ''}"

    if rg_config['with_mdl_principle']:
        save_path += '-MDL{}'.format(rg_config['mdl_principle_factor'])

    if rg_config['cultural_pressure_it_period'] != 'None':
        save_path += '-S{}L{}-{}-Reset{}'.\
          format(rg_config['cultural_speaker_substrate_size'],
          rg_config['cultural_listener_substrate_size'],
          rg_config['cultural_pressure_it_period'],
          rg_config['cultural_reset_strategy']+str(rg_config['cultural_reset_meta_learning_rate']) if 'meta' in rg_config['cultural_reset_strategy'] else rg_config['cultural_reset_strategy'])

    save_path += '-{}{}CulturalAgent-SEED{}-{}-obs_b{}_minib{}_lr{}-{}-tau0-{}-{}DistrTrain{}Test{}-stim{}-vocab{}over{}_{}{}'.\
      format(
      'ObjectCentric' if rg_config['object_centric'] else '',
      'Descriptive{}'.format(rg_config['descriptive_target_ratio']) if rg_config['descriptive'] else '',
      seed,
      rg_config['observability'],
      rg_config['batch_size'],
      args.mini_batch_size,
      rg_config['learning_rate'],
      rg_config['graphtype'],
      rg_config['tau0'],
      rg_config['distractor_sampling'],
      *rg_config['nbr_distractors'].values(),
      rg_config['nbr_stimulus'],
      rg_config['vocab_size'],
      rg_config['max_sentence_length'],
      rg_config['agent_architecture'],
      f"/{'Detached' if args.vae_detached_featout else ''}beta{vae_beta}-factor{factor_vae_gamma}" if 'BetaVAE' in rg_config['agent_architecture'] else ''
    )

    if 'MONet' in rg_config['agent_architecture'] or 'BetaVAE' in rg_config[
            'agent_architecture']:
        save_path += f"beta{vae_beta}-factor{factor_vae_gamma}-gamma{monet_gamma}-sigma{vae_observation_sigma}" if 'MONet' in rg_config[
            'agent_architecture'] else ''
        save_path += f"CEMC{maxCap}over{nbrepochtillmaxcap}" if vae_constrainedEncoding else ''
        save_path += f"UnsupSeg{rg_config['unsupervised_segmentation_factor']}" if rg_config[
            'unsupervised_segmentation_factor'] is not None else ''
        save_path += f"LossVAECoeff{args.vae_lambda}_{'UseMu' if args.vae_use_mu_value else ''}"

    if rg_config['use_feat_converter']:
        save_path += f"+FEATCONV"

    if rg_config['use_homoscedastic_multitasks_loss']:
        save_path += '+H**o'

    save_path += f"/{args.optimizer_type}/"

    if 'reinforce' in args.graphtype:
        save_path += f'/REINFORCE_EntropyCoeffNeg1m3/UnnormalizedDetLearningSignalHavrylovLoss/NegPG/'

    if 'obverter' in args.graphtype:
        save_path += f"Obverter{args.obverter_threshold_to_stop_message_generation}-{args.obverter_nbr_games_per_round}GPR/DEBUG/"
    else:
        save_path += f"STGS-{args.agent_type}-LSTM-CNN-Agent/"

    save_path += f"Periodic{args.metric_epoch_period}TS+DISComp-{'fast-' if args.metric_fast else ''}/"  #TestArchTanh/"

    save_path += f'DatasetRepTrain{args.nbr_train_dataset_repetition}Test{args.nbr_test_dataset_repetition}'

    rg_config['save_path'] = save_path

    print(save_path)

    from ReferentialGym.utils import statsLogger
    logger = statsLogger(path=save_path, dumpPeriod=100)

    # # Agents
    batch_size = 4
    nbr_distractors = 1 if "partial" in rg_config[
        "observability"] else agent_config["nbr_distractors"]["train"]
    nbr_stimulus = agent_config["nbr_stimulus"]
    obs_shape = [
        nbr_distractors + 1, nbr_stimulus, rg_config["stimulus_depth_dim"],
        rg_config["stimulus_resize_dim"], rg_config["stimulus_resize_dim"]
    ]
    vocab_size = rg_config["vocab_size"]
    max_sentence_length = rg_config["max_sentence_length"]

    if "obverter" in args.graphtype:
        from ReferentialGym.agents import DifferentiableObverterAgent
        speaker = DifferentiableObverterAgent(
            kwargs=agent_config,
            obs_shape=obs_shape,
            vocab_size=vocab_size,
            max_sentence_length=max_sentence_length,
            agent_id="s0",
            logger=logger,
            use_sentences_one_hot_vectors=args.use_sentences_one_hot_vectors,
            differentiable=args.differentiable)
    else:
        if "Baseline" in args.agent_type:
            from ReferentialGym.agents import LSTMCNNSpeaker
            speaker = LSTMCNNSpeaker(kwargs=agent_config,
                                     obs_shape=obs_shape,
                                     vocab_size=vocab_size,
                                     max_sentence_length=max_sentence_length,
                                     agent_id="s0",
                                     logger=logger)
        elif "EoSPriored" in args.agent_type:
            from ReferentialGym.agents import EoSPrioredLSTMCNNSpeaker
            speaker = EoSPrioredLSTMCNNSpeaker(
                kwargs=agent_config,
                obs_shape=obs_shape,
                vocab_size=vocab_size,
                max_sentence_length=max_sentence_length,
                agent_id="s0",
                logger=logger)
    print("Speaker:", speaker)

    listener_config = copy.deepcopy(agent_config)
    if args.shared_architecture:
        listener_config["cnn_encoder"] = speaker.cnn_encoder
    listener_config["nbr_distractors"] = rg_config["nbr_distractors"]["train"]
    batch_size = 4
    nbr_distractors = listener_config["nbr_distractors"]
    nbr_stimulus = listener_config["nbr_stimulus"]
    obs_shape = [
        nbr_distractors + 1, nbr_stimulus, rg_config["stimulus_depth_dim"],
        rg_config["stimulus_resize_dim"], rg_config["stimulus_resize_dim"]
    ]
    vocab_size = rg_config["vocab_size"]
    max_sentence_length = rg_config["max_sentence_length"]

    if "obverter" in args.graphtype:
        raise NotImplementedError
    else:
        if args.symbolic:
            from ReferentialGym.agents import LSTMMLPGenerativeListener
            listener = LSTMMLPGenerativeListener(
                kwargs=listener_config,
                obs_shape=obs_shape,
                vocab_size=vocab_size,
                max_sentence_length=max_sentence_length,
                agent_id="l0",
                logger=logger)
        else:
            from ReferentialGym.agents import LSTMCNNGenerativeListener
            listener = LSTMCNNGenerativeListener(
                kwargs=listener_config,
                obs_shape=obs_shape,
                vocab_size=vocab_size,
                max_sentence_length=max_sentence_length,
                agent_id="l0",
                logger=logger)

    if args.symbolic:
        assert args.agent_loss_type.lower() == 'ce'
        listener.input_stream_ids["listener"][
            "target_output"] = "current_dataloader:sample:speaker_exp_latents"

    print("Listener:", listener)

    # # Dataset:
    need_dict_wrapping = {}

    if "dSprites" in args.dataset:
        root = "./datasets/dsprites-dataset"
        train_dataset = ReferentialGym.datasets.dSpritesDataset(
            root=root,
            train=True,
            transform=rg_config["train_transform"],
            split_strategy=train_split_strategy)
        test_dataset = ReferentialGym.datasets.dSpritesDataset(
            root=root,
            train=False,
            transform=rg_config["test_transform"],
            split_strategy=test_split_strategy)
    else:
        raise NotImplementedError

    ## Modules:
    modules = {}

    from ReferentialGym import modules as rg_modules

    # Population:
    population_handler_id = "population_handler_0"
    population_handler_config = rg_config
    population_handler_stream_ids = {
        "current_speaker_streams_dict": "modules:current_speaker",
        "current_listener_streams_dict": "modules:current_listener",
        "epoch": "signals:epoch",
        "mode": "signals:mode",
        "global_it_datasample": "signals:global_it_datasample",
    }

    # Current Speaker:
    current_speaker_id = "current_speaker"

    # Current Listener:
    current_listener_id = "current_listener"

    modules[population_handler_id] = rg_modules.build_PopulationHandlerModule(
        id=population_handler_id,
        prototype_speaker=speaker,
        prototype_listener=listener,
        config=population_handler_config,
        input_stream_ids=population_handler_stream_ids)

    modules[current_speaker_id] = rg_modules.CurrentAgentModule(
        id=current_speaker_id, role="speaker")
    modules[current_listener_id] = rg_modules.CurrentAgentModule(
        id=current_listener_id, role="listener")

    homo_id = "homo0"
    homo_config = {"use_cuda": args.use_cuda}
    if args.homoscedastic_multitasks_loss:
        modules[homo_id] = rg_modules.build_HomoscedasticMultiTasksLossModule(
            id=homo_id,
            config=homo_config,
        )

    ## Pipelines:
    pipelines = {}

    # 0) Now that all the modules are known, let us build the optimization module:
    optim_id = "global_optim"
    optim_config = {
        "modules": modules,
        "learning_rate": args.lr,
        "optimizer_type": args.optimizer_type,
        "with_gradient_clip": rg_config["with_gradient_clip"],
        "adam_eps": rg_config["adam_eps"],
    }

    optim_module = rg_modules.build_OptimizationModule(
        id=optim_id,
        config=optim_config,
    )
    modules[optim_id] = optim_module

    grad_recorder_id = "grad_recorder"
    grad_recorder_module = rg_modules.build_GradRecorderModule(
        id=grad_recorder_id)
    modules[grad_recorder_id] = grad_recorder_module

    topo_sim_metric_id = "topo_sim_metric"
    topo_sim_metric_module = rg_modules.build_TopographicSimilarityMetricModule(
        id=topo_sim_metric_id,
        config={
            "parallel_TS_computation_max_workers": 16,
            "epoch_period": args.metric_epoch_period,
            "fast": args.metric_fast,
            "verbose": False,
            "vocab_size": rg_config["vocab_size"],
        })
    modules[topo_sim_metric_id] = topo_sim_metric_module
    """
  inst_coord_metric_id = "inst_coord_metric"
  inst_coord_metric_module = rg_modules.build_InstantaneousCoordinationMetricModule(id=inst_coord_metric_id,
    config = {
      "epoch_period":1,
    }
  )
  modules[inst_coord_metric_id] = inst_coord_metric_module
  """
    """
  dsprites_latent_metric_id = "dsprites_latent_metric"
  dsprites_latent_metric_module = rg_modules.build_dSpritesPerLatentAccuracyMetricModule(id=dsprites_latent_metric_id,
    config = {
      "epoch_period":1,
    }
  )
  modules[dsprites_latent_metric_id] = dsprites_latent_metric_module
  """

    speaker_factor_vae_disentanglement_metric_id = "speaker_factor_vae_disentanglement_metric"
    speaker_factor_vae_disentanglement_metric_input_stream_ids = {
        "model": "modules:current_speaker:ref:ref_agent:cnn_encoder",
        "representations": "modules:current_speaker:ref:ref_agent:features",
        "experiences": "current_dataloader:sample:speaker_experiences",
        "latent_representations":
        "current_dataloader:sample:speaker_exp_latents",
        "latent_values_representations":
        "current_dataloader:sample:speaker_exp_latents_values",
        "indices": "current_dataloader:sample:speaker_indices",
    }
    speaker_factor_vae_disentanglement_metric_module = rg_modules.build_FactorVAEDisentanglementMetricModule(
        id=speaker_factor_vae_disentanglement_metric_id,
        input_stream_ids=
        speaker_factor_vae_disentanglement_metric_input_stream_ids,
        config={
            "epoch_period": args.metric_epoch_period,
            "batch_size": 64,  #5,
            "nbr_train_points": 10000,  #3000,
            "nbr_eval_points": 5000,  #2000,
            "resample": False,
            "threshold": 5e-2,  #0.0,#1.0,
            "random_state_seed": args.seed,
            "verbose": False,
            "active_factors_only": True,
        })
    modules[
        speaker_factor_vae_disentanglement_metric_id] = speaker_factor_vae_disentanglement_metric_module

    listener_factor_vae_disentanglement_metric_id = "listener_factor_vae_disentanglement_metric"
    listener_factor_vae_disentanglement_metric_input_stream_ids = {
        "model": "modules:current_listener:ref:ref_agent:cnn_encoder",
        "representations":
        "modules:current_listener:ref:ref_agent:rnn_outputs",
        "experiences": "current_dataloader:sample:speaker_experiences",
        "latent_representations":
        "current_dataloader:sample:speaker_exp_latents",
        "latent_values_representations":
        "current_dataloader:sample:speaker_exp_latents_values",
        "indices": "current_dataloader:sample:speaker_indices",
    }
    listener_factor_vae_disentanglement_metric_module = rg_modules.build_FactorVAEDisentanglementMetricModule(
        id=listener_factor_vae_disentanglement_metric_id,
        input_stream_ids=
        listener_factor_vae_disentanglement_metric_input_stream_ids,
        config={
            "epoch_period": args.metric_epoch_period,
            "batch_size": 64,  #5,
            "nbr_train_points": 10000,  #3000,
            "nbr_eval_points": 5000,  #2000,
            "resample": False,
            "threshold": 5e-2,  #0.0,#1.0,
            "random_state_seed": args.seed,
            "verbose": False,
            "active_factors_only": True,
        })
    modules[
        listener_factor_vae_disentanglement_metric_id] = listener_factor_vae_disentanglement_metric_module

    logger_id = "per_epoch_logger"
    logger_module = rg_modules.build_PerEpochLoggerModule(id=logger_id)
    modules[logger_id] = logger_module

    pipelines["referential_game"] = [
        population_handler_id, current_speaker_id, current_listener_id
    ]

    pipelines[optim_id] = []
    if args.homoscedastic_multitasks_loss:
        pipelines[optim_id].append(homo_id)
    pipelines[optim_id].append(optim_id)
    """
  # Add gradient recorder module for debugging purposes:
  pipelines[optim_id].append(grad_recorder_id)
  """
    pipelines[optim_id].append(speaker_factor_vae_disentanglement_metric_id)
    pipelines[optim_id].append(listener_factor_vae_disentanglement_metric_id)
    pipelines[optim_id].append(topo_sim_metric_id)
    """
  pipelines[optim_id].append(inst_coord_metric_id)
  pipelines[optim_id].append(dsprites_latent_metric_id)
  """
    pipelines[optim_id].append(logger_id)

    rg_config["modules"] = modules
    rg_config["pipelines"] = pipelines

    dataset_args = {
        "dataset_class": "DualLabeledDataset",
        "modes": {
            "train": train_dataset,
            "test": test_dataset,
        },
        "need_dict_wrapping": need_dict_wrapping,
        "nbr_stimulus": rg_config["nbr_stimulus"],
        "distractor_sampling": rg_config["distractor_sampling"],
        "nbr_distractors": rg_config["nbr_distractors"],
        "observability": rg_config["observability"],
        "object_centric": rg_config["object_centric"],
        "descriptive": rg_config["descriptive"],
        "descriptive_target_ratio": rg_config["descriptive_target_ratio"],
    }

    refgame = ReferentialGym.make(config=rg_config, dataset_args=dataset_args)

    # In[22]:

    refgame.train(nbr_epoch=nbr_epoch, logger=logger, verbose_period=1)

    logger.flush()
def main():
    parser = argparse.ArgumentParser(
        description='LSTM CNN Agents: ST-GS Language Emergence.')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--use_cuda', action='store_true', default=False)
    parser.add_argument('--dataset',
                        type=str,
                        choices=[
                            'Sort-of-CLEVR',
                            'tiny-Sort-of-CLEVR',
                            'XSort-of-CLEVR',
                            'tiny-XSort-of-CLEVR',
                            'dSprites',
                        ],
                        help='dataset to train on.',
                        default='dSprites')
    parser.add_argument('--arch',
                        type=str,
                        choices=[
                            'CNN',
                            'CNN3x3',
                            'BN+CNN',
                            'BN+CNN3x3',
                            'BN+Coord2CNN3x3',
                            'BN+Coord4CNN3x3',
                            'Santoro2017-SoC-CNN',
                            'Santoro2017-CLEVR-CNN',
                            'Santoro2017-CLEVR-CNN3x3',
                            'Santoro2017-CLEVR-CoordCNN3x3',
                            'Santoro2017-CLEVR-EntityPrioredCNN3x3',
                            'Santoro2017-CLEVR-CNN7x4x4x3',
                        ],
                        help='model architecture to train',
                        default="BN+Coord4CNN3x3")
    parser.add_argument(
        '--graphtype',
        type=str,
        choices=[
            'straight_through_gumbel_softmax', 'reinforce',
            'baseline_reduced_reinforce', 'normalized_reinforce',
            'baseline_reduced_normalized_reinforce', 'max_entr_reinforce',
            'baseline_reduced_normalized_max_entr_reinforce',
            'argmax_reinforce', 'obverter'
        ],
        help=
        'type of graph to use during training of the speaker and listener.',
        default='straight_through_gumbel_softmax')
    parser.add_argument('--max_sentence_length', type=int, default=5)
    parser.add_argument('--vocab_size', type=int, default=100)
    parser.add_argument('--optimizer_type',
                        type=str,
                        choices=["adam", "sgd"],
                        default="adam")
    parser.add_argument('--agent_loss_type',
                        type=str,
                        choices=[
                            "Hinge",
                            "NLL",
                            "CE",
                        ],
                        default="Hinge")
    parser.add_argument('--agent_type',
                        type=str,
                        choices=[
                            "Baseline", "EoSPriored", "Transcoding",
                            "TranscodingSpeaker", "TranscodingListener"
                        ],
                        default="TranscodingSpeaker")
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--epoch', type=int, default=10000)
    parser.add_argument('--metric_epoch_period', type=int, default=20)
    parser.add_argument('--dataloader_num_worker', type=int, default=4)
    parser.add_argument('--metric_fast', action='store_true', default=False)
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--mini_batch_size', type=int, default=128)
    parser.add_argument('--dropout_prob', type=float, default=0.0)
    parser.add_argument('--emb_dropout_prob', type=float, default=0.8)
    parser.add_argument('--nbr_experience_repetition', type=int, default=1)
    parser.add_argument('--nbr_train_dataset_repetition', type=int, default=1)
    parser.add_argument('--nbr_test_dataset_repetition', type=int, default=1)
    parser.add_argument('--nbr_test_distractors', type=int, default=63)
    parser.add_argument('--nbr_train_distractors', type=int, default=47)
    parser.add_argument('--resizeDim',
                        default=32,
                        type=int,
                        help='input image resize')
    parser.add_argument('--st_gs_inv_tau0', type=float, default=0.2)
    parser.add_argument('--visual_decoder_nbr_steps', type=int, default=2)
    parser.add_argument('--transcoder_speaker_soft_attention',
                        action='store_true',
                        default=False)
    parser.add_argument('--transcoder_listener_soft_attention',
                        action='store_true',
                        default=False)
    parser.add_argument('--transcoder_st_gs_inv_tau0', type=float, default=0.5)
    parser.add_argument('--transcoder_visual_encoder_use_coord4',
                        action='store_true',
                        default=False)
    parser.add_argument('--shared_architecture',
                        action='store_true',
                        default=False)
    parser.add_argument('--same_head', action='store_true', default=False)
    parser.add_argument('--with_baseline', action='store_true', default=False)
    parser.add_argument('--homoscedastic_multitasks_loss',
                        action='store_true',
                        default=False)
    parser.add_argument('--use_curriculum_nbr_distractors',
                        action='store_true',
                        default=False)
    parser.add_argument('--use_feat_converter',
                        action='store_true',
                        default=False)
    parser.add_argument('--detached_heads', action='store_true', default=False)
    parser.add_argument('--test_id_analogy',
                        action='store_true',
                        default=False)
    parser.add_argument('--distractor_sampling',
                        type=str,
                        choices=[
                            "uniform",
                            "similarity-0.98",
                            "similarity-0.90",
                            "similarity-0.75",
                        ],
                        default="uniform")
    # Obverter Hyperparameters:
    parser.add_argument('--use_sentences_one_hot_vectors',
                        action='store_true',
                        default=False)
    parser.add_argument('--differentiable', action='store_true', default=False)
    parser.add_argument('--obverter_threshold_to_stop_message_generation',
                        type=float,
                        default=0.95)
    parser.add_argument('--obverter_nbr_games_per_round', type=int, default=4)
    # Cultural Bottleneck:
    parser.add_argument('--iterated_learning_scheme',
                        action='store_true',
                        default=False)
    parser.add_argument('--iterated_learning_period', type=int, default=4)
    parser.add_argument('--iterated_learning_rehearse_MDL',
                        action='store_true',
                        default=False)
    parser.add_argument('--iterated_learning_rehearse_MDL_factor',
                        type=float,
                        default=1.0)

    # Dataset Hyperparameters:
    parser.add_argument(
        '--train_test_split_strategy',
        type=str,
        choices=[
            'combinatorial2-Y-2-8-X-2-8-Orientation-40-N-Scale-6-N-Shape-3-N',  # Exp : DoRGsFurtherDise interweaved split simple XY normal             
            'combinatorial2-Y-2-S8-X-2-S8-Orientation-40-N-Scale-4-N-Shape-1-N',
            'combinatorial2-Y-32-N-X-32-N-Orientation-5-S4-Scale-1-S3-Shape-3-N',  #Sparse 2 Attributes: Orient.+Scale 64 imgs, 48 train, 16 test
            'combinatorial2-Y-2-S8-X-2-S8-Orientation-40-N-Scale-6-N-Shape-3-N',  # 4x Denser 2 Attributes: 256 imgs, 192 train, 64 test,

            # Heart shape: interpolation:
            'combinatorial2-Y-4-2-X-4-2-Orientation-40-N-Scale-6-N-Shape-3-N',  #Sparse 2 Attributes: X+Y 64 imgs, 48 train, 16 test
            'combinatorial2-Y-2-2-X-2-2-Orientation-40-N-Scale-6-N-Shape-3-N',  #Dense 2 Attributes: X+Y 256 imgs, 192 train, 64 test
            'combinatorial2-Y-8-2-X-8-2-Orientation-10-2-Scale-1-2-Shape-3-N',  #COMB2:Sparser 4 Attributes: 264 test / 120 train
            'combinatorial2-Y-4-2-X-4-2-Orientation-5-2-Scale-1-2-Shape-3-N',  #COMB2:Sparse 4 Attributes: 2112 test / 960 train
            'combinatorial2-Y-2-2-X-2-2-Orientation-2-2-Scale-1-2-Shape-3-N',  #COMB2:Dense 4 Attributes: ? test / ? train
            # Heart shape: Extrapolation:
            'combinatorial2-Y-4-S4-X-4-S4-Orientation-40-N-Scale-6-N-Shape-3-N',  #Sparse 2 Attributes: X+Y 64 imgs, 48 train, 16 test
            'combinatorial2-Y-8-S2-X-8-S2-Orientation-10-S2-Scale-1-S3-Shape-3-N',  #COMB2:Sparser 4 Attributes: 264 test / 120 train
            'combinatorial2-Y-4-S4-X-4-S4-Orientation-5-S4-Scale-1-S3-Shape-3-N',  #COMB2:Sparse 4 Attributes: 2112 test / 960 train
            'combinatorial2-Y-2-S8-X-2-S8-Orientation-2-S10-Scale-1-S3-Shape-3-N',  #COMB2:Dense 4 Attributes: ? test / ? train
            'combinatorial2-Y-4-S4-X-4-S4-Orientation-5-S4-Scale-6-N-Shape-3-N',  #COMB2 Sparse: 3 Attributes: XYOrientation 256 test / 256 train

            # Ovale shape:
            'combinatorial2-Y-1-S16-X-1-S16-Orientation-40-N-Scale-6-N-Shape-2-N',  # Denser 2 Attributes X+Y X 16/ Y 16/ --> 256 test / 768 train 
            'combinatorial2-Y-8-S2-X-8-S2-Orientation-10-S2-Scale-1-S3-Shape-2-N',  #COMB2:Sparser 4 Attributes: 264 test / 120 train
            'combinatorial2-Y-4-S4-X-4-S4-Orientation-5-S4-Scale-1-S3-Shape-2-N',  #COMB2:Sparse 4 Attributes: 2112 test / 960 train
            'combinatorial2-Y-2-S8-X-2-S8-Orientation-2-S10-Scale-1-S3-Shape-2-N',  #COMB2:Dense 4 Attributes: ? test / ? train

            #3 Attributes: denser 2 attributes(X+Y) with the sample size of Dense 4 attributes:
            'combinatorial2-Y-1-S16-X-1-S16-Orientation-2-S10-Scale-6-N-Shape-2-N',
            'combinatorial4-Y-4-S4-X-4-S4-Orientation-5-S4-Scale-1-S3-Shape-3-N',  #Sparse 4 Attributes: 192 test / 1344 train
        ],
        help='train/test split strategy',
        # INTER:
        default=
        'combinatorial2-Y-4-2-X-4-2-Orientation-40-N-Scale-6-N-Shape-3-N')
    # EXTRA:
    #default='combinatorial2-Y-4-S4-X-4-S4-Orientation-40-N-Scale-6-N-Shape-3-N')
    parser.add_argument(
        '--fast',
        action='store_true',
        default=False,
        help=
        'Disable the deterministic CuDNN. It is likely to make the computation faster.'
    )

    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------
    # VAE Hyperparameters:
    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------
    parser.add_argument('--vae_detached_featout',
                        action='store_true',
                        default=False)

    parser.add_argument('--vae_lambda', type=float, default=1.0)
    parser.add_argument('--vae_use_mu_value',
                        action='store_true',
                        default=False)

    parser.add_argument('--vae_nbr_latent_dim', type=int, default=128)
    parser.add_argument('--vae_decoder_nbr_layer', type=int, default=3)
    parser.add_argument('--vae_decoder_conv_dim', type=int, default=32)

    parser.add_argument('--vae_gaussian', action='store_true', default=False)
    parser.add_argument('--vae_gaussian_sigma', type=float, default=0.25)

    parser.add_argument('--vae_beta', type=float, default=1.0)
    parser.add_argument('--vae_factor_gamma', type=float, default=0.0)

    parser.add_argument('--vae_constrained_encoding',
                        action='store_true',
                        default=False)
    parser.add_argument('--vae_max_capacity', type=float, default=1e3)
    parser.add_argument('--vae_nbr_epoch_till_max_capacity',
                        type=int,
                        default=10)

    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------

    args = parser.parse_args()
    print(args)

    gaussian = args.vae_gaussian
    vae_observation_sigma = args.vae_gaussian_sigma

    vae_beta = args.vae_beta
    factor_vae_gamma = args.vae_factor_gamma

    vae_constrainedEncoding = args.vae_constrained_encoding
    maxCap = args.vae_max_capacity  #1e2
    nbrepochtillmaxcap = args.vae_nbr_epoch_till_max_capacity

    monet_gamma = 5e-1

    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------
    #--------------------------------------------------------------------------

    seed = args.seed

    # Following: https://pytorch.org/docs/stable/notes/randomness.html
    torch.manual_seed(seed)
    if hasattr(torch.backends, 'cudnn') and not (args.fast):
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    np.random.seed(seed)
    random.seed(seed)
    # # Hyperparameters:

    nbr_epoch = args.epoch

    cnn_feature_size = -1  #600 #128 #256 #
    # Except for VAEs...!

    stimulus_resize_dim = args.resizeDim  #64 #28

    normalize_rgb_values = False

    rgb_scaler = 1.0  #255.0
    from ReferentialGym.datasets.utils import ResizeNormalize
    transform = ResizeNormalize(size=stimulus_resize_dim,
                                normalize_rgb_values=normalize_rgb_values,
                                rgb_scaler=rgb_scaler)

    transform_degrees = 45
    transform_translate = (0.25, 0.25)

    multi_head_detached = args.detached_heads

    rg_config = {
        "observability": "partial",
        "max_sentence_length": args.max_sentence_length,
        "nbr_communication_round": 1,
        "nbr_distractors": {
            'train': args.nbr_train_distractors,
            'test': args.nbr_test_distractors
        },
        "distractor_sampling": args.distractor_sampling,
        # Default: use 'similarity-0.5'
        # otherwise the emerging language
        # will have very high ambiguity...
        # Speakers find the strategy of uttering
        # a word that is relevant to the class/label
        # of the target, seemingly.
        "descriptive": False,
        "descriptive_target_ratio": 0.97,
        # Default: 1-(1/(nbr_distractors+2)),
        # otherwise the agent find the local minimum
        # where it only predicts 'no-target'...
        "object_centric": False,
        "nbr_stimulus": 1,
        "graphtype": args.graphtype,
        "tau0": args.st_gs_inv_tau0,
        "gumbel_softmax_eps": 1e-6,
        "vocab_size": args.vocab_size,
        "symbol_embedding_size": 256,  #64
        "agent_architecture": args.
        arch,  #'CoordResNet18AvgPooled-2', #'BetaVAE', #'ParallelMONet', #'BetaVAE', #'CNN[-MHDPA]'/'[pretrained-]ResNet18[-MHDPA]-2'
        "agent_learning":
        'learning',  #'transfer_learning' : CNN's outputs are detached from the graph...
        "agent_loss_type": args.agent_loss_type,  #'NLL'
        "cultural_pressure_it_period": None,
        "cultural_speaker_substrate_size": 1,
        "cultural_listener_substrate_size": 1,
        "cultural_reset_strategy":
        "oldestL",  # "uniformSL" #"meta-oldestL-SGD"
        "cultural_reset_meta_learning_rate": 1e-3,

        # Obverter's Cultural Bottleneck:
        "iterated_learning_scheme": args.iterated_learning_scheme,
        "iterated_learning_period": args.iterated_learning_period,
        "iterated_learning_rehearse_MDL": args.iterated_learning_rehearse_MDL,
        "iterated_learning_rehearse_MDL_factor":
        args.iterated_learning_rehearse_MDL_factor,
        "obverter_stop_threshold": 0.95,  #0.0 if not in use.
        "obverter_nbr_games_per_round": args.obverter_nbr_games_per_round,
        "obverter_least_effort_loss": False,
        "obverter_least_effort_loss_weights": [1.0 for x in range(0, 10)],
        "batch_size": args.batch_size,
        "dataloader_num_worker": args.dataloader_num_worker,
        "stimulus_depth_dim": 1 if 'dSprites' in args.dataset else 3,
        "stimulus_resize_dim": stimulus_resize_dim,
        "learning_rate": args.lr,  #1e-3,
        "adam_eps": 1e-8,
        "dropout_prob": args.dropout_prob,
        "embedding_dropout_prob": args.emb_dropout_prob,
        "with_gradient_clip": False,
        "gradient_clip": 1e0,
        "use_homoscedastic_multitasks_loss":
        args.homoscedastic_multitasks_loss,
        "use_feat_converter": args.use_feat_converter,
        "use_curriculum_nbr_distractors": args.use_curriculum_nbr_distractors,
        "curriculum_distractors_window_size": 25,  #100,
        "unsupervised_segmentation_factor": None,  #1e5
        "nbr_experience_repetition": args.nbr_experience_repetition,
        "with_utterance_penalization": False,
        "with_utterance_promotion": False,
        "utterance_oov_prob":
        0.5,  # Expected penalty of observing out-of-vocabulary words. 
        # The greater this value, the greater the loss/cost.
        "utterance_factor": 1e-2,
        "with_speaker_entropy_regularization": False,
        "with_listener_entropy_regularization": False,
        "entropy_regularization_factor": -1e-2,
        "with_mdl_principle": False,
        "mdl_principle_factor": 5e-2,
        "with_weight_maxl1_loss": False,
        "use_cuda": args.use_cuda,

        # "train_transform":          T.Compose([T.RandomAffine(degrees=transform_degrees,
        #                                                       translate=transform_translate,
        #                                                       scale=None,
        #                                                       shear=None,
        #                                                       resample=False,
        #                                                       fillcolor=0),
        #                                         transform]),

        # "test_transform":           T.Compose([T.RandomAffine(degrees=transform_degrees,
        #                                                      translate=transform_translate,
        #                                                      scale=None,
        #                                                      shear=None,
        #                                                      resample=False,
        #                                                      fillcolor=0),
        #                                         transform]),
        "train_transform": transform,
        "test_transform": transform,
    }

    ## Train set:
    train_split_strategy = args.train_test_split_strategy
    test_split_strategy = train_split_strategy

    ## Agent Configuration:
    agent_config = copy.deepcopy(rg_config)
    agent_config['use_cuda'] = rg_config['use_cuda']
    agent_config['homoscedastic_multitasks_loss'] = rg_config[
        'use_homoscedastic_multitasks_loss']
    agent_config['use_feat_converter'] = rg_config['use_feat_converter']
    agent_config['max_sentence_length'] = rg_config['max_sentence_length']
    agent_config['nbr_distractors'] = rg_config['nbr_distractors'][
        'train'] if rg_config['observability'] == 'full' else 0
    agent_config['nbr_stimulus'] = rg_config['nbr_stimulus']
    agent_config['nbr_communication_round'] = rg_config[
        'nbr_communication_round']
    agent_config['descriptive'] = rg_config['descriptive']
    agent_config['gumbel_softmax_eps'] = rg_config['gumbel_softmax_eps']
    agent_config['agent_learning'] = rg_config['agent_learning']

    # Obverter:
    agent_config[
        'use_obverter_threshold_to_stop_message_generation'] = args.obverter_threshold_to_stop_message_generation

    agent_config['symbol_embedding_size'] = rg_config['symbol_embedding_size']

    # Recurrent Convolutional Architecture:
    agent_config['architecture'] = rg_config['agent_architecture']
    agent_config['dropout_prob'] = rg_config['dropout_prob']
    agent_config['embedding_dropout_prob'] = rg_config[
        'embedding_dropout_prob']

    if 'Santoro2017-SoC' in agent_config['architecture']:
        # For a fair comparison between CNN an VAEs:
        # the CNN is augmented with one final FC layer reducing to the latent space shape.
        # Need to use feat converter too:
        #rg_config['use_feat_converter'] = True
        #agent_config['use_feat_converter'] = True

        # Otherwise, the VAE alone may be augmented:
        # This approach assumes that the VAE latent dimension size
        # is acting as a prior which is part of the comparison...
        rg_config['use_feat_converter'] = False
        agent_config['use_feat_converter'] = False

        agent_config['cnn_encoder_channels'] = [
            'BN32', 'BN64', 'BN128', 'BN256'
        ]
        if '3x3' in agent_config['architecture']:
            agent_config['cnn_encoder_kernels'] = [3, 3, 3, 3]
        elif '7x4x4x3' in agent_config['architecture']:
            agent_config['cnn_encoder_kernels'] = [7, 4, 4, 3]
        else:
            agent_config['cnn_encoder_kernels'] = [4, 4, 4, 4]
        agent_config['cnn_encoder_strides'] = [2, 2, 2, 2]
        agent_config['cnn_encoder_paddings'] = [1, 1, 1, 1]
        agent_config['cnn_encoder_fc_hidden_units'] = []
        # the last FC layer is provided by the cnn_encoder_feature_dim parameter below...

        # For a fair comparison between CNN an VAEs:
        #agent_config['cnn_encoder_feature_dim'] = args.vae_nbr_latent_dim
        # Otherwise:
        cnn_feature_size = 100
        agent_config['cnn_encoder_feature_dim'] = cnn_feature_size
        # N.B.: if cnn_encoder_fc_hidden_units is [],
        # then this last parameter does not matter.
        # The cnn encoder is not topped by a FC network.

        agent_config['cnn_encoder_mini_batch_size'] = args.mini_batch_size
        agent_config['feat_converter_output_size'] = cnn_feature_size

        if 'MHDPA' in agent_config['architecture']:
            agent_config['mhdpa_nbr_head'] = 4
            agent_config['mhdpa_nbr_rec_update'] = 1
            agent_config['mhdpa_nbr_mlp_unit'] = 256
            agent_config['mhdpa_interaction_dim'] = 128

        agent_config['temporal_encoder_nbr_hidden_units'] = rg_config[
            'nbr_stimulus'] * cnn_feature_size
        agent_config['temporal_encoder_nbr_rnn_layers'] = 0
        agent_config['temporal_encoder_mini_batch_size'] = args.mini_batch_size

        agent_config['transcoder_nbr_hidden_units'] = 256
        agent_config['transcoder_nbr_rnn_layers'] = 1
        agent_config['transcoder_attention_interaction_dim'] = 128
        agent_config[
            'transcoder_speaker_soft_attention'] = args.transcoder_speaker_soft_attention
        agent_config[
            'transcoder_listener_soft_attention'] = args.transcoder_listener_soft_attention
        agent_config[
            'transcoder_st_gs_inv_tau0'] = args.transcoder_st_gs_inv_tau0
        agent_config[
            'transcoder_visual_encoder_use_coord4'] = args.transcoder_visual_encoder_use_coord4

        agent_config['symbol_processing_nbr_hidden_units'] = 256
        agent_config['symbol_processing_nbr_rnn_layers'] = 1

        # Transcoding Listener:
        agent_config['textual_encoder_nbr_hidden_units'] = 256
        agent_config['textual_encoder_nbr_rnn_layers'] = 1

        agent_config[
            'visual_decoder_nbr_steps'] = args.visual_decoder_nbr_steps
        agent_config['visual_decoder_nbr_hidden_units'] = 256
        agent_config['visual_decoder_nbr_rnn_layers'] = 1

        agent_config['visual_decoder_mlp_dropout_prob'] = 0.0

    elif 'Santoro2017-CLEVR' in agent_config['architecture']:
        # For a fair comparison between CNN an VAEs:
        # the CNN is augmented with one final FC layer reducing to the latent space shape.
        # Need to use feat converter too:
        #rg_config['use_feat_converter'] = True
        #agent_config['use_feat_converter'] = True

        # Otherwise, the VAE alone may be augmented:
        # This approach assumes that the VAE latent dimension size
        # is acting as a prior which is part of the comparison...
        rg_config['use_feat_converter'] = False
        agent_config['use_feat_converter'] = False

        agent_config['cnn_encoder_channels'] = ['BN24', 'BN24', 'BN24', 'BN24']
        if '3x3' in agent_config['architecture']:
            agent_config['cnn_encoder_kernels'] = [3, 3, 3, 3]
        elif '7x4x4x3' in agent_config['architecture']:
            agent_config['cnn_encoder_kernels'] = [7, 4, 4, 3]
        else:
            agent_config['cnn_encoder_kernels'] = [4, 4, 4, 4]
        agent_config['cnn_encoder_strides'] = [2, 2, 2, 2]
        agent_config['cnn_encoder_paddings'] = [1, 1, 1, 1]
        agent_config['cnn_encoder_fc_hidden_units'] = []
        # the last FC layer is provided by the cnn_encoder_feature_dim parameter below...

        # For a fair comparison between CNN an VAEs:
        #agent_config['cnn_encoder_feature_dim'] = args.vae_nbr_latent_dim
        # Otherwise:
        agent_config['cnn_encoder_feature_dim'] = cnn_feature_size
        # N.B.: if cnn_encoder_fc_hidden_units is [],
        # then this last parameter does not matter.
        # The cnn encoder is not topped by a FC network.

        agent_config['cnn_encoder_mini_batch_size'] = args.mini_batch_size
        agent_config['feat_converter_output_size'] = cnn_feature_size

        if 'MHDPA' in agent_config['architecture']:
            agent_config['mhdpa_nbr_head'] = 4
            agent_config['mhdpa_nbr_rec_update'] = 1
            agent_config['mhdpa_nbr_mlp_unit'] = 256
            agent_config['mhdpa_interaction_dim'] = 128

        agent_config['temporal_encoder_nbr_hidden_units'] = 0
        agent_config['temporal_encoder_nbr_rnn_layers'] = 0
        agent_config['temporal_encoder_mini_batch_size'] = args.mini_batch_size

        agent_config['transcoder_nbr_hidden_units'] = 256
        agent_config['transcoder_nbr_rnn_layers'] = 1
        agent_config['transcoder_attention_interaction_dim'] = 128
        agent_config[
            'transcoder_speaker_soft_attention'] = args.transcoder_speaker_soft_attention
        agent_config[
            'transcoder_listener_soft_attention'] = args.transcoder_listener_soft_attention
        agent_config[
            'transcoder_st_gs_inv_tau0'] = args.transcoder_st_gs_inv_tau0
        agent_config[
            'transcoder_visual_encoder_use_coord4'] = args.transcoder_visual_encoder_use_coord4

        agent_config['symbol_processing_nbr_hidden_units'] = 256
        agent_config['symbol_processing_nbr_rnn_layers'] = 1

        # Transcoding Listener:
        agent_config['textual_encoder_nbr_hidden_units'] = 256
        agent_config['textual_encoder_nbr_rnn_layers'] = 1

        agent_config[
            'visual_decoder_nbr_steps'] = args.visual_decoder_nbr_steps
        agent_config['visual_decoder_nbr_hidden_units'] = 256
        agent_config['visual_decoder_nbr_rnn_layers'] = 1

        agent_config['visual_decoder_mlp_dropout_prob'] = 0.0

    elif 'CNN' in agent_config['architecture']:
        rg_config['use_feat_converter'] = False
        agent_config['use_feat_converter'] = False

        if 'BN' in args.arch:
            #agent_config['cnn_encoder_channels'] = ['BN32','BN32','BN64','BN64']
            agent_config['cnn_encoder_channels'] = ['BN32', 'BN64', 'BN128']
        else:
            #agent_config['cnn_encoder_channels'] = [32,32,64,64]
            agent_config['cnn_encoder_channels'] = [32, 64, 128]

        if '3x3' in agent_config['architecture']:
            agent_config['cnn_encoder_kernels'] = [3, 3, 3]
        elif '7x4x4x3' in agent_config['architecture']:
            agent_config['cnn_encoder_kernels'] = [7, 4, 3]
        else:
            agent_config['cnn_encoder_kernels'] = [4, 4, 4]
        agent_config['cnn_encoder_strides'] = [2, 2, 2]
        agent_config['cnn_encoder_paddings'] = [1, 1, 1]
        agent_config['cnn_encoder_fc_hidden_units'] = []  #[128,]
        # the last FC layer is provided by the cnn_encoder_feature_dim parameter below...

        # For a fair comparison between CNN an VAEs:
        #agent_config['cnn_encoder_feature_dim'] = args.vae_nbr_latent_dim
        agent_config['cnn_encoder_feature_dim'] = cnn_feature_size
        # N.B.: if cnn_encoder_fc_hidden_units is [],
        # then this last parameter does not matter.
        # The cnn encoder is not topped by a FC network.

        agent_config['cnn_encoder_mini_batch_size'] = args.mini_batch_size
        agent_config['feat_converter_output_size'] = cnn_feature_size

        if 'MHDPA' in agent_config['architecture']:
            agent_config['mhdpa_nbr_head'] = 4
            agent_config['mhdpa_nbr_rec_update'] = 1
            agent_config['mhdpa_nbr_mlp_unit'] = 256
            agent_config['mhdpa_interaction_dim'] = 128

        agent_config['temporal_encoder_nbr_hidden_units'] = 0
        agent_config['temporal_encoder_nbr_rnn_layers'] = 0
        agent_config['temporal_encoder_mini_batch_size'] = args.mini_batch_size

        agent_config['transcoder_nbr_hidden_units'] = 256
        agent_config['transcoder_nbr_rnn_layers'] = 1
        agent_config['transcoder_attention_interaction_dim'] = 128
        agent_config[
            'transcoder_speaker_soft_attention'] = args.transcoder_speaker_soft_attention
        agent_config[
            'transcoder_listener_soft_attention'] = args.transcoder_listener_soft_attention
        agent_config[
            'transcoder_st_gs_inv_tau0'] = args.transcoder_st_gs_inv_tau0
        agent_config[
            'transcoder_visual_encoder_use_coord4'] = args.transcoder_visual_encoder_use_coord4

        agent_config['symbol_processing_nbr_hidden_units'] = 256
        agent_config['symbol_processing_nbr_rnn_layers'] = 1

        # Transcoding Listener:
        agent_config['textual_encoder_nbr_hidden_units'] = 256
        agent_config['textual_encoder_nbr_rnn_layers'] = 1

        agent_config[
            'visual_decoder_nbr_steps'] = args.visual_decoder_nbr_steps
        agent_config['visual_decoder_nbr_hidden_units'] = 256
        agent_config['visual_decoder_nbr_rnn_layers'] = 1

        agent_config['visual_decoder_mlp_dropout_prob'] = 0.0

    else:
        raise NotImplementedError

    save_path = f"./{args.dataset}+DualLabeled/{'Attached' if not(multi_head_detached) else 'Detached'}Heads"
    save_path += f"/{nbr_epoch}Ep_Emb{rg_config['symbol_embedding_size']}_CNN{cnn_feature_size}to{args.vae_nbr_latent_dim}"
    if args.shared_architecture:
        save_path += "/shared_architecture"
    save_path += f"/TrainNOTF_TestNOTF/"
    save_path += f"Dropout{rg_config['dropout_prob']}_DPEmb{rg_config['embedding_dropout_prob']}"
    save_path += f"_BN_{rg_config['agent_learning']}/"
    save_path += f"{rg_config['agent_loss_type']}"

    if 'dSprites' in args.dataset:
        train_test_strategy = f"-{test_split_strategy}"
        if test_split_strategy != train_split_strategy:
            train_test_strategy = f"/train_{train_split_strategy}/test_{test_split_strategy}"
        save_path += f"/dSprites{train_test_strategy}"

    save_path += f"/OBS{rg_config['stimulus_resize_dim']}X{rg_config['stimulus_depth_dim']}C-Rep{rg_config['nbr_experience_repetition']}"

    if rg_config['use_curriculum_nbr_distractors']:
        save_path += f"+W{rg_config['curriculum_distractors_window_size']}Curr"
    if rg_config['with_utterance_penalization']:
        save_path += "+Tau-10-OOV{}PenProb{}".format(
            rg_config['utterance_factor'], rg_config['utterance_oov_prob'])
    if rg_config['with_utterance_promotion']:
        save_path += "+Tau-10-OOV{}ProProb{}".format(
            rg_config['utterance_factor'], rg_config['utterance_oov_prob'])

    if rg_config['with_gradient_clip']:
        save_path += '+ClipGrad{}'.format(rg_config['gradient_clip'])

    if rg_config['with_speaker_entropy_regularization']:
        save_path += 'SPEntrReg{}'.format(
            rg_config['entropy_regularization_factor'])
    if rg_config['with_listener_entropy_regularization']:
        save_path += 'LSEntrReg{}'.format(
            rg_config['entropy_regularization_factor'])

    if rg_config['iterated_learning_scheme']:
        save_path += f"-ILM{rg_config['iterated_learning_period']}{'+RehearseMDL{}'.format(rg_config['iterated_learning_rehearse_MDL_factor']) if rg_config['iterated_learning_rehearse_MDL'] else ''}"

    if rg_config['with_mdl_principle']:
        save_path += '-MDL{}'.format(rg_config['mdl_principle_factor'])

    if rg_config['cultural_pressure_it_period'] != 'None':
        save_path += '-S{}L{}-{}-Reset{}'.\
          format(rg_config['cultural_speaker_substrate_size'],
          rg_config['cultural_listener_substrate_size'],
          rg_config['cultural_pressure_it_period'],
          rg_config['cultural_reset_strategy']+str(rg_config['cultural_reset_meta_learning_rate']) if 'meta' in rg_config['cultural_reset_strategy'] else rg_config['cultural_reset_strategy'])

    save_path += '-{}{}CulturalAgent-SEED{}-{}-obs_b{}_minib{}_lr{}-{}-tau0-{}-{}DistrTrain{}Test{}-stim{}-vocab{}over{}_{}{}'.\
      format(
      'ObjectCentric' if rg_config['object_centric'] else '',
      'Descriptive{}'.format(rg_config['descriptive_target_ratio']) if rg_config['descriptive'] else '',
      seed,
      rg_config['observability'],
      rg_config['batch_size'],
      args.mini_batch_size,
      rg_config['learning_rate'],
      rg_config['graphtype'],
      rg_config['tau0'],
      rg_config['distractor_sampling'],
      *rg_config['nbr_distractors'].values(),
      rg_config['nbr_stimulus'],
      rg_config['vocab_size'],
      rg_config['max_sentence_length'],
      rg_config['agent_architecture'],
      f"/{'Detached' if args.vae_detached_featout else ''}beta{vae_beta}-factor{factor_vae_gamma}" if 'BetaVAE' in rg_config['agent_architecture'] else ''
    )

    if 'MONet' in rg_config['agent_architecture'] or 'BetaVAE' in rg_config[
            'agent_architecture']:
        save_path += f"beta{vae_beta}-factor{factor_vae_gamma}-gamma{monet_gamma}-sigma{vae_observation_sigma}" if 'MONet' in rg_config[
            'agent_architecture'] else ''
        save_path += f"CEMC{maxCap}over{nbrepochtillmaxcap}" if vae_constrainedEncoding else ''
        save_path += f"UnsupSeg{rg_config['unsupervised_segmentation_factor']}" if rg_config[
            'unsupervised_segmentation_factor'] is not None else ''
        save_path += f"LossVAECoeff{args.vae_lambda}_{'UseMu' if args.vae_use_mu_value else ''}"

    if rg_config['use_feat_converter']:
        save_path += f"+FEATCONV"

    if rg_config['use_homoscedastic_multitasks_loss']:
        save_path += '+H**o'

    save_path += f"/{args.optimizer_type}/"

    if 'reinforce' in args.graphtype:
        save_path += f'/REINFORCE_EntropyCoeffNeg1m3/UnnormalizedDetLearningSignalHavrylovLoss/NegPG/'

    if 'obverter' in args.graphtype:
        save_path += f"withPopulationHandlerModule/Obverter{args.obverter_threshold_to_stop_message_generation}-{args.obverter_nbr_games_per_round}GPR/DEBUG/"
    else:
        save_path += f"withPopulationHandlerModule/STGS-{args.agent_type}-LSTM-CNN-Agent/"

    if 'Transcoding' in args.agent_type and 'TranscodingSpeaker' not in args.agent_type:
        save_path += f"{'Soft' if args.transcoder_listener_soft_attention else ''}TranscodingSteps{args.visual_decoder_nbr_steps}"
        save_path += f"-TransListSTGS{args.transcoder_st_gs_inv_tau0}/"

    if 'Transcoding' in args.agent_type and 'TranscodingListener' not in args.agent_type:
        save_path += f"{'Soft' if args.transcoder_speaker_soft_attention else ''}TranscodingSpeaker-Coord{'4' if args.transcoder_visual_encoder_use_coord4 else '2' }/"

    save_path += f"Periodic{args.metric_epoch_period}TS+DISComp-{'fast-' if args.metric_fast else ''}/TestTranscoding/"

    if args.same_head:
        save_path += "same_head/"

    if args.test_id_analogy:
        save_path += 'withAnalogyTest/'
    else:
        save_path += 'NoAnalogyTest/'

    save_path += f'DatasetRepTrain{args.nbr_train_dataset_repetition}Test{args.nbr_test_dataset_repetition}'

    rg_config['save_path'] = save_path

    print(save_path)

    from ReferentialGym.utils import statsLogger
    logger = statsLogger(path=save_path, dumpPeriod=100)

    # # Agents
    batch_size = 4
    nbr_distractors = 1 if 'partial' in rg_config[
        'observability'] else agent_config['nbr_distractors']['train']
    nbr_stimulus = agent_config['nbr_stimulus']
    obs_shape = [
        nbr_distractors + 1, nbr_stimulus, rg_config['stimulus_depth_dim'],
        rg_config['stimulus_resize_dim'], rg_config['stimulus_resize_dim']
    ]
    vocab_size = rg_config['vocab_size']
    max_sentence_length = rg_config['max_sentence_length']

    if 'obverter' in args.graphtype:
        from ReferentialGym.agents import DifferentiableObverterAgent
        speaker = DifferentiableObverterAgent(
            kwargs=agent_config,
            obs_shape=obs_shape,
            vocab_size=vocab_size,
            max_sentence_length=max_sentence_length,
            agent_id='s0',
            logger=logger,
            use_sentences_one_hot_vectors=args.use_sentences_one_hot_vectors,
            differentiable=args.differentiable)
    else:
        if 'EoSPriored' in args.agent_type:
            from ReferentialGym.agents import EoSPrioredLSTMCNNSpeaker
            speaker = EoSPrioredLSTMCNNSpeaker(
                kwargs=agent_config,
                obs_shape=obs_shape,
                vocab_size=vocab_size,
                max_sentence_length=max_sentence_length,
                agent_id='s0',
                logger=logger)
        elif 'Transcoding' in args.agent_type and 'Listener' not in args.agent_type:
            from ReferentialGym.agents import TranscodingLSTMCNNSpeaker
            speaker = TranscodingLSTMCNNSpeaker(
                kwargs=agent_config,
                obs_shape=obs_shape,
                vocab_size=vocab_size,
                max_sentence_length=max_sentence_length,
                agent_id='s0',
                logger=logger)
        else:
            from ReferentialGym.agents import LSTMCNNSpeaker
            speaker = LSTMCNNSpeaker(kwargs=agent_config,
                                     obs_shape=obs_shape,
                                     vocab_size=vocab_size,
                                     max_sentence_length=max_sentence_length,
                                     agent_id='s0',
                                     logger=logger)

    print("Speaker:", speaker)

    for name, param in speaker.named_parameters():
        print(name, param.shape)

    listener_config = copy.deepcopy(agent_config)
    if args.shared_architecture:
        listener_config['cnn_encoder'] = speaker.cnn_encoder
    listener_config['nbr_distractors'] = rg_config['nbr_distractors']['train']
    batch_size = 4
    nbr_distractors = listener_config['nbr_distractors']
    nbr_stimulus = listener_config['nbr_stimulus']
    obs_shape = [
        nbr_distractors + 1, nbr_stimulus, rg_config['stimulus_depth_dim'],
        rg_config['stimulus_resize_dim'], rg_config['stimulus_resize_dim']
    ]
    vocab_size = rg_config['vocab_size']
    max_sentence_length = rg_config['max_sentence_length']

    if 'obverter' in args.graphtype:
        listener = DifferentiableObverterAgent(
            kwargs=listener_config,
            obs_shape=obs_shape,
            vocab_size=vocab_size,
            max_sentence_length=max_sentence_length,
            agent_id='l0',
            logger=logger,
            use_sentences_one_hot_vectors=args.use_sentences_one_hot_vectors,
            differentiable=args.differentiable)
    elif 'TranscodingListener' in args.agent_type and 'Speaker' not in args.agent_type:
        from ReferentialGym.agents import TranscodingLSTMCNNListener
        listener = TranscodingLSTMCNNListener(
            kwargs=listener_config,
            obs_shape=obs_shape,
            vocab_size=vocab_size,
            max_sentence_length=max_sentence_length,
            agent_id='l0',
            logger=logger)
    else:
        from ReferentialGym.agents import LSTMCNNListener
        listener = LSTMCNNListener(kwargs=listener_config,
                                   obs_shape=obs_shape,
                                   vocab_size=vocab_size,
                                   max_sentence_length=max_sentence_length,
                                   agent_id='l0',
                                   logger=logger)
    print("Listener:", listener)

    for name, param in listener.named_parameters():
        print(name, param.shape)

    # # Dataset:
    need_dict_wrapping = {}

    if 'dSprites' in args.dataset:
        root = './datasets/dsprites-dataset'
        train_dataset = ReferentialGym.datasets.dSpritesDataset(
            root=root,
            train=True,
            transform=rg_config['train_transform'],
            split_strategy=train_split_strategy)
        test_dataset = ReferentialGym.datasets.dSpritesDataset(
            root=root,
            train=False,
            transform=rg_config['test_transform'],
            split_strategy=test_split_strategy)
    else:
        raise NotImplementedError

    ## Modules:
    modules = {}

    from ReferentialGym import modules as rg_modules

    # Population:
    population_handler_id = "population_handler_0"
    population_handler_config = rg_config
    population_handler_stream_ids = {
        "modules:current_speaker": "current_speaker_streams_dict",
        "modules:current_listener": "current_listener_streams_dict",
        "signals:epoch": "epoch",
        "signals:mode": "mode",
        "signals:global_it_datasample": "global_it_datasample",
    }

    # Current Speaker:
    current_speaker_id = "current_speaker"

    # Current Listener:
    current_listener_id = "current_listener"

    modules[population_handler_id] = rg_modules.build_PopulationHandlerModule(
        id=population_handler_id,
        prototype_speaker=speaker,
        prototype_listener=listener,
        config=population_handler_config,
        input_stream_ids=population_handler_stream_ids)

    modules[current_speaker_id] = rg_modules.CurrentAgentModule(
        id=current_speaker_id, role="speaker")
    modules[current_listener_id] = rg_modules.CurrentAgentModule(
        id=current_listener_id, role="listener")

    homo_id = "homo0"
    homo_config = {"use_cuda": args.use_cuda}
    if args.homoscedastic_multitasks_loss:
        modules[homo_id] = rg_modules.build_HomoscedasticMultiTasksLossModule(
            id=homo_id,
            config=homo_config,
        )

    ## Pipelines:
    pipelines = {}

    # 0) Now that all the modules are known, let us build the optimization module:
    optim_id = "global_optim"
    optim_config = {
        "modules": modules,
        "learning_rate": args.lr,
        "optimizer_type": args.optimizer_type,
        "with_gradient_clip": rg_config["with_gradient_clip"],
        "adam_eps": rg_config["adam_eps"],
    }

    optim_module = rg_modules.build_OptimizationModule(
        id=optim_id,
        config=optim_config,
    )
    modules[optim_id] = optim_module

    grad_recorder_id = "grad_recorder"
    grad_recorder_module = rg_modules.build_GradRecorderModule(
        id=grad_recorder_id)
    modules[grad_recorder_id] = grad_recorder_module

    topo_sim_metric_id = "topo_sim_metric"
    topo_sim_metric_module = rg_modules.build_TopographicSimilarityMetricModule(
        id=topo_sim_metric_id,
        config={
            "parallel_TS_computation_max_workers": 16,
            "epoch_period": args.metric_epoch_period,
            "fast": args.metric_fast,
            "verbose": False,
            "vocab_size": rg_config["vocab_size"],
        })
    modules[topo_sim_metric_id] = topo_sim_metric_module

    inst_coord_metric_id = "inst_coord_metric"
    inst_coord_metric_module = rg_modules.build_InstantaneousCoordinationMetricModule(
        id=inst_coord_metric_id, config={
            "epoch_period": 1,
        })
    modules[inst_coord_metric_id] = inst_coord_metric_module

    speaker_factor_vae_disentanglement_metric_id = "speaker_factor_vae_disentanglement_metric"
    speaker_factor_vae_disentanglement_metric_input_stream_ids = {
        'modules:current_speaker:ref:ref_agent:cnn_encoder': 'model',
        'modules:current_speaker:ref:ref_agent:features': 'representations',
        'current_dataloader:sample:speaker_experiences': 'experiences',
        'current_dataloader:sample:speaker_exp_latents':
        'latent_representations',
        'current_dataloader:sample:speaker_exp_latents_values':
        'latent_values_representations',
        'current_dataloader:sample:speaker_indices': 'indices',
    }
    speaker_factor_vae_disentanglement_metric_module = rg_modules.build_FactorVAEDisentanglementMetricModule(
        id=speaker_factor_vae_disentanglement_metric_id,
        input_stream_ids=
        speaker_factor_vae_disentanglement_metric_input_stream_ids,
        config={
            "epoch_period": args.metric_epoch_period,
            "batch_size": 64,  #5,
            "nbr_train_points": 10000,  #3000,
            "nbr_eval_points": 5000,  #2000,
            "resample": False,
            "threshold": 5e-2,  #0.0,#1.0,
            "random_state_seed": args.seed,
            "verbose": False,
            "active_factors_only": True,
        })
    modules[
        speaker_factor_vae_disentanglement_metric_id] = speaker_factor_vae_disentanglement_metric_module

    listener_factor_vae_disentanglement_metric_id = "listener_factor_vae_disentanglement_metric"
    listener_factor_vae_disentanglement_metric_input_stream_ids = {
        'modules:current_listener:ref:ref_agent:cnn_encoder': 'model',
        'modules:current_listener:ref:ref_agent:features': 'representations',
        'current_dataloader:sample:listener_experiences': 'experiences',
        'current_dataloader:sample:listener_exp_latents':
        'latent_representations',
        'current_dataloader:sample:listener_exp_latents_values':
        'latent_values_representations',
        'current_dataloader:sample:listener_indices': 'indices',
    }
    listener_factor_vae_disentanglement_metric_module = rg_modules.build_FactorVAEDisentanglementMetricModule(
        id=listener_factor_vae_disentanglement_metric_id,
        input_stream_ids=
        listener_factor_vae_disentanglement_metric_input_stream_ids,
        config={
            "epoch_period": args.metric_epoch_period,
            "batch_size": 64,  #5,
            "nbr_train_points": 10000,  #3000,
            "nbr_eval_points": 5000,  #2000,
            "resample": False,
            "threshold": 5e-2,  #0.0,#1.0,
            "random_state_seed": args.seed,
            "verbose": False,
            "active_factors_only": True,
        })
    modules[
        listener_factor_vae_disentanglement_metric_id] = listener_factor_vae_disentanglement_metric_module

    logger_id = "per_epoch_logger"
    logger_module = rg_modules.build_PerEpochLoggerModule(id=logger_id)
    modules[logger_id] = logger_module

    pipelines['referential_game'] = [
        population_handler_id, current_speaker_id, current_listener_id
    ]

    pipelines[optim_id] = []
    if args.homoscedastic_multitasks_loss:
        pipelines[optim_id].append(homo_id)
    pipelines[optim_id].append(optim_id)
    '''
  # Add gradient recorder module for debugging purposes:
  pipelines[optim_id].append(grad_recorder_id)
  '''
    pipelines[optim_id].append(speaker_factor_vae_disentanglement_metric_id)
    pipelines[optim_id].append(listener_factor_vae_disentanglement_metric_id)
    pipelines[optim_id].append(topo_sim_metric_id)
    pipelines[optim_id].append(inst_coord_metric_id)
    pipelines[optim_id].append(logger_id)

    rg_config["modules"] = modules
    rg_config["pipelines"] = pipelines

    dataset_args = {
        "dataset_class": "DualLabeledDataset",
        "modes": {
            "train": train_dataset,
            "test": test_dataset,
        },
        "need_dict_wrapping": need_dict_wrapping,
        "nbr_stimulus": rg_config['nbr_stimulus'],
        "distractor_sampling": rg_config['distractor_sampling'],
        "nbr_distractors": rg_config['nbr_distractors'],
        "observability": rg_config['observability'],
        "object_centric": rg_config['object_centric'],
        "descriptive": rg_config['descriptive'],
        "descriptive_target_ratio": rg_config['descriptive_target_ratio'],
    }

    refgame = ReferentialGym.make(config=rg_config, dataset_args=dataset_args)

    # In[22]:

    refgame.train(nbr_epoch=nbr_epoch, logger=logger, verbose_period=1)

    logger.flush()
def main():
    """Train LSTM-CNN speaker/listener agents in a ReferentialGym referential game.

    Parses command-line arguments, seeds the torch / numpy / random RNGs,
    builds the referential-game and agent configurations, instantiates the
    speaker and listener agents, loads the CIFAR10 or dSprites datasets,
    assembles the ReferentialGym modules and pipelines (population handler,
    optimizer, topographic-similarity, instantaneous-coordination and
    FactorVAE-disentanglement metrics, per-epoch logger), then trains for
    ``--epoch`` epochs and flushes the stats logger.

    Raises:
        NotImplementedError: for an unsupported architecture, dataset or
            agent type.
    """
    parser = argparse.ArgumentParser(
        description='LSTM CNN Agents: Example Language Emergence.')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--use_cuda', action='store_true', default=False)
    parser.add_argument('--dataset',
                        type=str,
                        choices=[
                            'CIFAR10',
                            'dSprites',
                        ],
                        help='dataset to train on.',
                        default='CIFAR10')
    parser.add_argument('--arch',
                        type=str,
                        choices=[
                            'CNN',
                            'CNN3x3',
                            'BN+CNN',
                            'BN+CNN3x3',
                            'BN+Coord2CNN3x3',
                            'BN+Coord4CNN3x3',
                        ],
                        help='model architecture to train',
                        default="BN+CNN3x3")
    parser.add_argument(
        '--graphtype',
        type=str,
        choices=[
            'straight_through_gumbel_softmax', 'reinforce',
            'baseline_reduced_reinforce', 'normalized_reinforce',
            'baseline_reduced_normalized_reinforce', 'max_entr_reinforce',
            'baseline_reduced_normalized_max_entr_reinforce',
            'argmax_reinforce', 'obverter'
        ],
        help=
        'type of graph to use during training of the speaker and listener.',
        default='straight_through_gumbel_softmax')
    parser.add_argument('--max_sentence_length', type=int, default=20)
    parser.add_argument('--vocab_size', type=int, default=100)
    parser.add_argument('--symbol_embedding_size', type=int, default=256)
    parser.add_argument('--optimizer_type',
                        type=str,
                        choices=["adam", "sgd"],
                        default="adam")
    parser.add_argument('--agent_loss_type',
                        type=str,
                        choices=[
                            "Hinge",
                            "NLL",
                            "CE",
                        ],
                        default="Hinge")
    parser.add_argument('--agent_type',
                        type=str,
                        choices=[
                            "Baseline",
                            "EoSPriored",
                        ],
                        default="Baseline")
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--metric_epoch_period', type=int, default=20)
    # `--metric_fast` defaults to True, so passing it is a no-op; it is kept
    # for backward compatibility. `--no_metric_fast` shares the same
    # destination and is the only way to turn fast metric computation off.
    parser.add_argument('--metric_fast', action='store_true', default=True)
    parser.add_argument('--no_metric_fast',
                        dest='metric_fast',
                        action='store_false')
    parser.add_argument('--dataloader_num_worker', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--mini_batch_size', type=int, default=128)
    parser.add_argument('--dropout_prob', type=float, default=0.0)
    parser.add_argument('--embedding_dropout_prob', type=float, default=0.8)
    parser.add_argument('--nbr_test_distractors', type=int, default=63)
    parser.add_argument('--nbr_train_distractors', type=int, default=47)
    parser.add_argument('--resizeDim',
                        default=32,
                        type=int,
                        help='input image resize')
    parser.add_argument('--shared_architecture',
                        action='store_true',
                        default=False)
    parser.add_argument('--homoscedastic_multitasks_loss',
                        action='store_true',
                        default=False)
    parser.add_argument('--use_curriculum_nbr_distractors',
                        action='store_true',
                        default=False)
    parser.add_argument('--descriptive', action='store_true', default=False)
    parser.add_argument('--distractor_sampling',
                        type=str,
                        choices=[
                            "uniform",
                            "similarity-0.98",
                            "similarity-0.90",
                            "similarity-0.75",
                        ],
                        default="similarity-0.75")

    # Dataset Hyperparameters:
    parser.add_argument(
        '--train_test_split_strategy',
        type=str,
        choices=[  # Heart shape: interpolation:
            'combinatorial2-Y-4-2-X-4-2-Orientation-40-N-Scale-6-N-Shape-3-N',  #Sparse 2 Attributes: X+Y 64 imgs, 48 train, 16 test
            'combinatorial2-Y-2-2-X-2-2-Orientation-40-N-Scale-6-N-Shape-3-N',  #Dense 2 Attributes: X+Y 256 imgs, 192 train, 64 test
            'combinatorial2-Y-8-2-X-8-2-Orientation-10-2-Scale-1-2-Shape-3-N',  #COMB2:Sparser 4 Attributes: 264 test / 120 train
            'combinatorial2-Y-4-2-X-4-2-Orientation-5-2-Scale-1-2-Shape-3-N',  #COMB2:Sparse 4 Attributes: 2112 test / 960 train
            'combinatorial2-Y-2-2-X-2-2-Orientation-2-2-Scale-1-2-Shape-3-N',  #COMB2:Dense 4 Attributes: ? test / ? train
            'combinatorial2-Y-4-2-X-4-2-Orientation-5-2-Scale-6-N-Shape-3-N',  #COMB2 Sparse: 3 Attributes: XYOrientation 256 test / 256 train
            # Heart shape: Extrapolation:
            'combinatorial2-Y-4-S4-X-4-S4-Orientation-40-N-Scale-6-N-Shape-3-N',  #Sparse 2 Attributes: X+Y 64 imgs, 48 train, 16 test
            'combinatorial2-Y-8-S2-X-8-S2-Orientation-10-S2-Scale-1-S3-Shape-3-N',  #COMB2:Sparser 4 Attributes: 264 test / 120 train
            'combinatorial2-Y-4-S4-X-4-S4-Orientation-5-S4-Scale-1-S3-Shape-3-N',  #COMB2:Sparse 4 Attributes: 2112 test / 960 train
            'combinatorial2-Y-2-S8-X-2-S8-Orientation-2-S10-Scale-1-S3-Shape-3-N',  #COMB2:Dense 4 Attributes: ? test / ? train
            'combinatorial2-Y-4-S4-X-4-S4-Orientation-5-S4-Scale-6-N-Shape-3-N',  #COMB2 Sparse: 3 Attributes: XYOrientation 256 test / 256 train
        ],
        help='train/test split strategy',
        # INTER:
        default=
        'combinatorial2-Y-4-2-X-4-2-Orientation-40-N-Scale-6-N-Shape-3-N')
    parser.add_argument(
        '--fast',
        action='store_true',
        default=False,
        help=
        'Disable the deterministic CuDNN. It is likely to make the computation faster.'
    )

    #--------------------------------------------------------------------------

    args = parser.parse_args()
    print(args)

    #--------------------------------------------------------------------------

    seed = args.seed

    # Following: https://pytorch.org/docs/stable/notes/randomness.html
    torch.manual_seed(seed)
    if hasattr(torch.backends, 'cudnn') and not (args.fast):
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    np.random.seed(seed)
    random.seed(seed)

    # # Hyperparameters:
    nbr_epoch = args.epoch

    cnn_feature_size = -1
    stimulus_resize_dim = args.resizeDim
    normalize_rgb_values = False

    rgb_scaler = 1.0  #255.0
    from ReferentialGym.datasets.utils import ResizeNormalize
    transform = ResizeNormalize(size=stimulus_resize_dim,
                                normalize_rgb_values=normalize_rgb_values,
                                rgb_scaler=rgb_scaler)

    #--------------------------------------------------------------------------

    rg_config = {
        "vocab_size": args.vocab_size,
        "max_sentence_length": args.max_sentence_length,
        "nbr_communication_round": 1,
        "observability": "partial",
        "nbr_distractors": {
            'train': args.nbr_train_distractors,
            'test': args.nbr_test_distractors
        },
        "distractor_sampling": args.distractor_sampling,
        # Default: use 'similarity-0.5'
        # otherwise the emerging language
        # will have very high ambiguity...
        # Speakers find the strategy of uttering
        # a word that is relevant to the class/label
        # of the target, seemingly.
        "descriptive": args.descriptive,
        "descriptive_target_ratio":
        1 - (1 / (args.nbr_train_distractors + 2)),  #0.97,
        # Default: 1-(1/(nbr_distractors+2)),
        # otherwise the agent find the local minimum
        # where it only predicts 'no-target'...
        "object_centric": False,
        "nbr_stimulus": 1,
        "graphtype": args.graphtype,
        "tau0": 0.2,
        "gumbel_softmax_eps": 1e-6,
        "agent_architecture": args.arch,
        "agent_learning": 'learning',
        "agent_loss_type": args.agent_loss_type,  #'NLL'

        # "cultural_pressure_it_period": None,
        # "cultural_speaker_substrate_size":  1,
        # "cultural_listener_substrate_size":  1,
        # "cultural_reset_strategy":  "oldestL", # "uniformSL" #"meta-oldestL-SGD"
        # "cultural_reset_meta_learning_rate":  1e-3,
        "batch_size": args.batch_size,
        "dataloader_num_worker": args.dataloader_num_worker,
        "stimulus_depth_dim": 1 if 'dSprites' in args.dataset else 3,
        "stimulus_resize_dim": stimulus_resize_dim,
        "learning_rate": args.lr,  #1e-3,
        "adam_eps": 1e-8,
        "with_gradient_clip": False,
        "gradient_clip": 1e0,
        "with_weight_maxl1_loss": False,
        "use_homoscedastic_multitasks_loss":
        args.homoscedastic_multitasks_loss,
        "use_cuda": args.use_cuda,
        "train_transform": transform,
        "test_transform": transform,
    }

    ## Train set:
    train_split_strategy = args.train_test_split_strategy
    test_split_strategy = train_split_strategy

    ## Agent Configuration:
    agent_config = copy.deepcopy(rg_config)
    # NOTE: from here on agent_config['nbr_distractors'] is an int, not the
    # {'train': ..., 'test': ...} dict held by rg_config.
    agent_config['nbr_distractors'] = rg_config['nbr_distractors'][
        'train'] if rg_config['observability'] == 'full' else 0

    # Recurrent Convolutional Architecture:
    agent_config['architecture'] = rg_config['agent_architecture']
    agent_config['dropout_prob'] = args.dropout_prob
    agent_config['embedding_dropout_prob'] = args.embedding_dropout_prob
    agent_config['symbol_embedding_size'] = args.symbol_embedding_size

    if 'CNN' in agent_config['architecture']:
        rg_config['use_feat_converter'] = False
        agent_config['use_feat_converter'] = False

        if 'BN' in args.arch:
            agent_config['cnn_encoder_channels'] = [
                'BN32', 'BN32', 'BN64', 'BN64'
            ]
        else:
            agent_config['cnn_encoder_channels'] = [32, 32, 64, 64]

        if '3x3' in agent_config['architecture']:
            agent_config['cnn_encoder_kernels'] = [3, 3, 3, 3]
        elif '7x4x4x3' in agent_config['architecture']:
            agent_config['cnn_encoder_kernels'] = [7, 4, 4, 3]
        else:
            agent_config['cnn_encoder_kernels'] = [4, 4, 4, 4]
        agent_config['cnn_encoder_strides'] = [2, 2, 2, 2]
        agent_config['cnn_encoder_paddings'] = [1, 1, 1, 1]
        agent_config['cnn_encoder_fc_hidden_units'] = []  #[128,]
        # the last FC layer is provided by the cnn_encoder_feature_dim parameter below...

        # For a fair comparison between CNN an VAEs:
        #agent_config['cnn_encoder_feature_dim'] = args.vae_nbr_latent_dim
        agent_config['cnn_encoder_feature_dim'] = cnn_feature_size
        # N.B.: if cnn_encoder_fc_hidden_units is [],
        # then this last parameter does not matter.
        # The cnn encoder is not topped by a FC network.

        agent_config['cnn_encoder_mini_batch_size'] = args.mini_batch_size
        agent_config['feat_converter_output_size'] = cnn_feature_size

        if 'MHDPA' in agent_config['architecture']:
            agent_config['mhdpa_nbr_head'] = 4
            agent_config['mhdpa_nbr_rec_update'] = 1
            agent_config['mhdpa_nbr_mlp_unit'] = 256
            agent_config['mhdpa_interaction_dim'] = 128

        agent_config['temporal_encoder_nbr_hidden_units'] = 0
        agent_config['temporal_encoder_nbr_rnn_layers'] = 0
        agent_config['temporal_encoder_mini_batch_size'] = args.mini_batch_size
        agent_config['symbol_processing_nbr_hidden_units'] = agent_config[
            'temporal_encoder_nbr_hidden_units']
        agent_config['symbol_processing_nbr_rnn_layers'] = 1

    else:
        raise NotImplementedError

    save_path = f"./Example/{args.dataset}+DualLabeled"
    save_path += f"/{nbr_epoch}Ep_Emb{args.symbol_embedding_size}_CNN{cnn_feature_size}"
    if args.shared_architecture:
        save_path += "/shared_architecture/"
    save_path += f"Dropout{args.dropout_prob}_DPEmb{args.embedding_dropout_prob}"
    save_path += f"_{rg_config['agent_learning']}"
    save_path += f"_{rg_config['agent_loss_type']}"

    if 'dSprites' in args.dataset:
        train_test_strategy = f"-{test_split_strategy}"
        if test_split_strategy != train_split_strategy:
            train_test_strategy = f"/train_{train_split_strategy}/test_{test_split_strategy}"
        save_path += f"/dSprites{train_test_strategy}"

    save_path += f"/OBS{rg_config['stimulus_resize_dim']}X{rg_config['stimulus_depth_dim']}C"

    save_path += '-{}{}Agent-SEED{}-{}-obs_b{}_minib{}_lr{}-{}-tau0-{}-{}DistrTrain{}Test{}-stim{}-vocab{}over{}_{}'.format(
        'ObjectCentric' if rg_config['object_centric'] else '',
        'Descriptive{}'.format(rg_config['descriptive_target_ratio'])
        if rg_config['descriptive'] else '',
        seed,
        rg_config['observability'],
        rg_config['batch_size'],
        args.mini_batch_size,
        rg_config['learning_rate'],
        rg_config['graphtype'],
        rg_config['tau0'],
        rg_config['distractor_sampling'],
        *rg_config['nbr_distractors'].values(),
        rg_config['nbr_stimulus'],
        rg_config['vocab_size'],
        rg_config['max_sentence_length'],
        rg_config['agent_architecture'],
    )

    if rg_config['use_feat_converter']:
        save_path += "+FEATCONV"

    if rg_config['use_homoscedastic_multitasks_loss']:
        # NOTE(review): this suffix looks garbled (presumably '+Homo'); it is
        # left unchanged so existing experiment save paths stay stable.
        save_path += '+H**o'

    save_path += f"/{args.optimizer_type}/"

    if 'reinforce' in args.graphtype:
        save_path += '/REINFORCE_EntropyCoeffNeg1m3/UnnormalizedDetLearningSignalHavrylovLoss/'

    save_path += f"withPopulationHandlerModule/STGS-{args.agent_type}-LSTM-CNN-Agent/"

    save_path += f"Periodic{args.metric_epoch_period}TS+DISComp-{'fast-' if args.metric_fast else ''}/"

    rg_config['save_path'] = save_path

    print(save_path)

    from ReferentialGym.utils import statsLogger
    logger = statsLogger(path=save_path, dumpPeriod=100)

    # # Agents
    # agent_config['nbr_distractors'] is already the integer train
    # distractor count (or 0), so it must not be indexed with 'train' here.
    if 'partial' in rg_config['observability']:
        nbr_distractors = 1
    else:
        nbr_distractors = agent_config['nbr_distractors']
    nbr_stimulus = agent_config['nbr_stimulus']
    obs_shape = [
        nbr_distractors + 1, nbr_stimulus, rg_config['stimulus_depth_dim'],
        rg_config['stimulus_resize_dim'], rg_config['stimulus_resize_dim']
    ]
    vocab_size = rg_config['vocab_size']
    max_sentence_length = rg_config['max_sentence_length']

    if 'Baseline' in args.agent_type:
        from ReferentialGym.agents import LSTMCNNSpeaker
        speaker = LSTMCNNSpeaker(kwargs=agent_config,
                                 obs_shape=obs_shape,
                                 vocab_size=vocab_size,
                                 max_sentence_length=max_sentence_length,
                                 agent_id='s0',
                                 logger=logger)
    elif 'EoSPriored' in args.agent_type:
        from ReferentialGym.agents import EoSPrioredLSTMCNNSpeaker
        speaker = EoSPrioredLSTMCNNSpeaker(
            kwargs=agent_config,
            obs_shape=obs_shape,
            vocab_size=vocab_size,
            max_sentence_length=max_sentence_length,
            agent_id='s0',
            logger=logger)
    else:
        # Guard against `speaker` being left unbound.
        raise NotImplementedError
    print("Speaker:", speaker)

    listener_config = copy.deepcopy(agent_config)
    if args.shared_architecture:
        # The listener reuses the speaker's CNN encoder object.
        listener_config['cnn_encoder'] = speaker.cnn_encoder
    listener_config['nbr_distractors'] = rg_config['nbr_distractors']['train']
    nbr_distractors = listener_config['nbr_distractors']
    nbr_stimulus = listener_config['nbr_stimulus']
    obs_shape = [
        nbr_distractors + 1, nbr_stimulus, rg_config['stimulus_depth_dim'],
        rg_config['stimulus_resize_dim'], rg_config['stimulus_resize_dim']
    ]
    vocab_size = rg_config['vocab_size']
    max_sentence_length = rg_config['max_sentence_length']

    from ReferentialGym.agents import LSTMCNNListener
    listener = LSTMCNNListener(kwargs=listener_config,
                               obs_shape=obs_shape,
                               vocab_size=vocab_size,
                               max_sentence_length=max_sentence_length,
                               agent_id='l0',
                               logger=logger)
    print("Listener:", listener)

    # # Dataset:
    need_dict_wrapping = []
    if 'CIFAR10' in args.dataset:
        train_dataset = torchvision.datasets.CIFAR10(
            root='./datasets/CIFAR10/',
            train=True,
            transform=rg_config['train_transform'],
            download=True)
        test_dataset = torchvision.datasets.CIFAR10(
            root='./datasets/CIFAR10/',
            train=False,
            transform=rg_config['test_transform'],
            download=True)
        need_dict_wrapping = ['train', 'test']
    elif 'dSprites' in args.dataset:
        root = './datasets/dsprites-dataset'
        train_dataset = ReferentialGym.datasets.dSpritesDataset(
            root=root,
            train=True,
            transform=rg_config['train_transform'],
            split_strategy=train_split_strategy)
        test_dataset = ReferentialGym.datasets.dSpritesDataset(
            root=root,
            train=False,
            transform=rg_config['test_transform'],
            split_strategy=test_split_strategy)
    else:
        raise NotImplementedError

    ## Modules:
    modules = {}

    from ReferentialGym import modules as rg_modules

    # Population:
    population_handler_id = "population_handler_0"
    # NOTE: this is the same dict object as rg_config (not a copy), so the
    # "modules"/"pipelines" entries added to rg_config below are visible here.
    population_handler_config = rg_config
    population_handler_stream_ids = {
        "modules:current_speaker": "current_speaker_streams_dict",
        "modules:current_listener": "current_listener_streams_dict",
        "signals:epoch": "epoch",
        "signals:mode": "mode",
        "signals:global_it_datasample": "global_it_datasample",
    }

    # Current Speaker:
    current_speaker_id = "current_speaker"
    # Current Listener:
    current_listener_id = "current_listener"

    modules[population_handler_id] = rg_modules.build_PopulationHandlerModule(
        id=population_handler_id,
        prototype_speaker=speaker,
        prototype_listener=listener,
        config=population_handler_config,
        input_stream_ids=population_handler_stream_ids)
    modules[current_speaker_id] = rg_modules.CurrentAgentModule(
        id=current_speaker_id, role="speaker")
    modules[current_listener_id] = rg_modules.CurrentAgentModule(
        id=current_listener_id, role="listener")

    if args.homoscedastic_multitasks_loss:
        homo_id = "homo0"
        homo_config = {"use_cuda": args.use_cuda}
        modules[homo_id] = rg_modules.build_HomoscedasticMultiTasksLossModule(
            id=homo_id,
            config=homo_config,
        )

    ## Pipelines:
    pipelines = {}

    # 0)  Now that all the trainable modules are known,
    #     let's build the optimization module:
    optim_id = "global_optim"
    optim_config = {
        "modules": modules,
        "learning_rate": args.lr,
        "optimizer_type": args.optimizer_type,
        "with_gradient_clip": rg_config["with_gradient_clip"],
        "adam_eps": rg_config["adam_eps"],
    }

    optim_module = rg_modules.build_OptimizationModule(
        id=optim_id,
        config=optim_config,
    )
    modules[optim_id] = optim_module

    grad_recorder_id = "grad_recorder"
    grad_recorder_module = rg_modules.build_GradRecorderModule(
        id=grad_recorder_id)
    modules[grad_recorder_id] = grad_recorder_module

    topo_sim_metric_id = "topo_sim_metric"
    topo_sim_metric_module = rg_modules.build_TopographicSimilarityMetricModule(
        id=topo_sim_metric_id,
        config={
            "parallel_TS_computation_max_workers": 16,
            "epoch_period": args.metric_epoch_period,
            "fast": args.metric_fast,
            "verbose": False,
            "vocab_size": rg_config["vocab_size"],
        })
    modules[topo_sim_metric_id] = topo_sim_metric_module

    inst_coord_metric_id = "inst_coord_metric"
    inst_coord_metric_module = rg_modules.build_InstantaneousCoordinationMetricModule(
        id=inst_coord_metric_id, config={
            "epoch_period": 1,
        })
    modules[inst_coord_metric_id] = inst_coord_metric_module

    speaker_factor_vae_disentanglement_metric_id = "speaker_factor_vae_disentanglement_metric"
    speaker_factor_vae_disentanglement_metric_input_stream_ids = {
        'modules:current_speaker:ref:ref_agent:cnn_encoder': 'model',
        'modules:current_speaker:ref:ref_agent:features': 'representations',
        'current_dataloader:sample:speaker_experiences': 'experiences',
        'current_dataloader:sample:speaker_exp_latents':
        'latent_representations',
        'current_dataloader:sample:speaker_exp_latents_values':
        'latent_values_representations',
        'current_dataloader:sample:speaker_indices': 'indices',
    }
    speaker_factor_vae_disentanglement_metric_module = rg_modules.build_FactorVAEDisentanglementMetricModule(
        id=speaker_factor_vae_disentanglement_metric_id,
        input_stream_ids=
        speaker_factor_vae_disentanglement_metric_input_stream_ids,
        config={
            "epoch_period": args.metric_epoch_period,
            "batch_size": 64,  #5,
            "nbr_train_points": 10000,  #3000,
            "nbr_eval_points": 5000,  #2000,
            "resample": False,
            "threshold": 5e-2,  #0.0,#1.0,
            "random_state_seed": args.seed,
            "verbose": False,
            "active_factors_only": True,
        })
    modules[
        speaker_factor_vae_disentanglement_metric_id] = speaker_factor_vae_disentanglement_metric_module

    listener_factor_vae_disentanglement_metric_id = "listener_factor_vae_disentanglement_metric"
    listener_factor_vae_disentanglement_metric_input_stream_ids = {
        'modules:current_listener:ref:ref_agent:cnn_encoder': 'model',
        'modules:current_listener:ref:ref_agent:features': 'representations',
        'current_dataloader:sample:listener_experiences': 'experiences',
        'current_dataloader:sample:listener_exp_latents':
        'latent_representations',
        'current_dataloader:sample:listener_exp_latents_values':
        'latent_values_representations',
        'current_dataloader:sample:listener_indices': 'indices',
    }
    listener_factor_vae_disentanglement_metric_module = rg_modules.build_FactorVAEDisentanglementMetricModule(
        id=listener_factor_vae_disentanglement_metric_id,
        input_stream_ids=
        listener_factor_vae_disentanglement_metric_input_stream_ids,
        config={
            "epoch_period": args.metric_epoch_period,
            "batch_size": 64,  #5,
            "nbr_train_points": 10000,  #3000,
            "nbr_eval_points": 5000,  #2000,
            "resample": False,
            "threshold": 5e-2,  #0.0,#1.0,
            "random_state_seed": args.seed,
            "verbose": False,
            "active_factors_only": True,
        })
    modules[
        listener_factor_vae_disentanglement_metric_id] = listener_factor_vae_disentanglement_metric_module

    logger_id = "per_epoch_logger"
    logger_module = rg_modules.build_PerEpochLoggerModule(id=logger_id)
    modules[logger_id] = logger_module

    pipelines['referential_game'] = [
        population_handler_id, current_speaker_id, current_listener_id
    ]

    pipelines[optim_id] = []
    if args.homoscedastic_multitasks_loss:
        pipelines[optim_id].append(homo_id)
    pipelines[optim_id].append(optim_id)
    # Add gradient recorder module for debugging purposes:
    # pipelines[optim_id].append(grad_recorder_id)
    pipelines[optim_id].append(speaker_factor_vae_disentanglement_metric_id)
    pipelines[optim_id].append(listener_factor_vae_disentanglement_metric_id)
    pipelines[optim_id].append(topo_sim_metric_id)
    pipelines[optim_id].append(inst_coord_metric_id)
    pipelines[optim_id].append(logger_id)

    rg_config["modules"] = modules
    rg_config["pipelines"] = pipelines

    dataset_args = {
        "dataset_class": "DualLabeledDataset",
        "modes": {
            "train": train_dataset,
            "test": test_dataset,
        },
        "need_dict_wrapping": need_dict_wrapping,
        "nbr_stimulus": rg_config['nbr_stimulus'],
        "distractor_sampling": rg_config['distractor_sampling'],
        "nbr_distractors": rg_config['nbr_distractors'],
        "observability": rg_config['observability'],
        "object_centric": rg_config['object_centric'],
        "descriptive": rg_config['descriptive'],
        "descriptive_target_ratio": rg_config['descriptive_target_ratio'],
    }

    refgame = ReferentialGym.make(config=rg_config, dataset_args=dataset_args)

    refgame.train(nbr_epoch=nbr_epoch, logger=logger, verbose_period=1)

    logger.flush()