示例#1
0
    def __init__(
        self,
        model,
        state=None,
        epoch=0,
        name="",
        run_name="",
    ):
        """Configure the trainer for ``model`` from the global "Training" parameters.

        Args:
            model: Model to train; its parameters are handed to the optimizer
                via ``self.get_model_parameters``.
            state: Forwarded to ``self.set_state`` once the optimizer and the
                epoch/segment counters have been initialised.
            epoch: Initial value for both the train and test epoch counters.
            name: Stored as ``self.name``.
            run_name: Stored as ``self.run_name``.

        Raises:
            ValueError: If the configured optimizer is neither "adam" nor
                "sgd". Previously this case fell through silently and left
                ``self.optim`` undefined.
        """
        # Vocabulary lookup tables built from the instruction corpus.
        _, _, _, corpus = get_all_instructions()
        self.token2word, self.word2token = get_word_to_token_map(corpus)

        # Hyper-parameters come from the "Training" section of the current
        # parameter set.
        self.params = get_current_parameters()["Training"]
        self.batch_size = self.params['batch_size']
        self.weight_decay = self.params['weight_decay']
        self.optimizer = self.params['optimizer']
        self.num_loaders = self.params['num_loaders']
        self.lr = self.params['lr']

        self.name = name
        self.dataset_names = None

        n_params = get_n_params(model)
        n_params_tr = get_n_trainable_params(model)
        print("Training Model:")
        print("Number of model parameters: " + str(n_params))
        print("Trainable model parameters: " + str(n_params_tr))

        self.model = model
        self.run_name = run_name
        if self.optimizer == "adam":
            self.optim = optim.Adam(self.get_model_parameters(self.model),
                                    self.lr,
                                    weight_decay=self.weight_decay)
        elif self.optimizer == "sgd":
            self.optim = optim.SGD(self.get_model_parameters(self.model),
                                   self.lr,
                                   weight_decay=self.weight_decay,
                                   momentum=0.9)
        else:
            # Fail here rather than with an AttributeError on the first
            # access to self.optim later on.
            raise ValueError("Unsupported optimizer: {!r}".format(
                self.optimizer))
        self.train_epoch_num = epoch
        self.train_segment = 0
        self.test_epoch_num = epoch
        self.test_segment = 0
        self.set_state(state)
        self.batch_num = 0
示例#2
0
    def __init__(self,
                 model_real,
                 model_sim,
                 model_critic,
                 model_oracle_critic=None,
                 state=None,
                 epoch=0):
        """Configure joint training of a real-domain model, a sim-domain model and a critic.

        Hyper-parameters are read from the "Training" section of the current
        parameter set; the run name comes from the "Setup" section.

        Args:
            model_real: Real-domain model. Unless weight sharing is disabled in
                the parameters, ``steal_cross_domain_modules(model_sim)`` is
                called on it.
            model_sim: Sim-domain model.
            model_critic: Critic model; it gets its own optimizer built from
                ``self.critic_parameters()``.
            model_oracle_critic: Optional oracle critic, stored as-is.
            state: Forwarded to ``self.set_state`` at the end of construction.
            epoch: Initial value for both the train and test epoch counters.

        Raises:
            ValueError: If the configured optimizer is neither "adam" nor "sgd".
        """
        # Vocabulary lookup tables built from the instruction corpus.
        _, _, _, corpus = get_all_instructions()
        self.token2word, self.word2token = get_word_to_token_map(corpus)

        self.params = get_current_parameters()["Training"]
        self.run_name = get_current_parameters()["Setup"]["run_name"]
        self.batch_size = self.params['batch_size']
        self.iterations_per_epoch = self.params.get("iterations_per_epoch",
                                                    None)
        self.weight_decay = self.params['weight_decay']
        self.optimizer = self.params['optimizer']
        self.critic_loaders = self.params['critic_loaders']
        self.model_common_loaders = self.params['model_common_loaders']
        self.model_sim_loaders = self.params['model_sim_loaders']
        self.lr = self.params['lr']
        self.critic_steps = self.params['critic_steps']
        self.model_steps = self.params['model_steps']
        self.critic_batch_size = self.params["critic_batch_size"]
        self.model_batch_size = self.params["model_batch_size"]
        self.disable_wloss = self.params["disable_wloss"]
        self.sim_steps_per_real_step = self.params.get(
            "sim_steps_per_real_step", 1)

        self.real_dataset_names = self.params.get("real_dataset_names")
        self.sim_dataset_names = self.params.get("sim_dataset_names")

        self.bidata = self.params.get("bidata", False)
        self.sim_steps_per_common_step = self.params.get(
            "sim_steps_per_common_step", 1)

        n_params_real = get_n_params(model_real)
        n_params_real_tr = get_n_trainable_params(model_real)
        n_params_sim = get_n_params(model_sim)
        n_params_sim_tr = get_n_trainable_params(model_sim)
        n_params_c = get_n_params(model_critic)
        # Fixed: this used get_n_params, so the "trainable" line below simply
        # repeated the total parameter count for the critic.
        n_params_c_tr = get_n_trainable_params(model_critic)

        print("Training Model:")
        print("Real # model parameters: " + str(n_params_real))
        print("Real # trainable parameters: " + str(n_params_real_tr))
        print("Sim  # model parameters: " + str(n_params_sim))
        print("Sim  # trainable parameters: " + str(n_params_sim_tr))
        print("Critic  # model parameters: " + str(n_params_c))
        print("Critic  # trainable parameters: " + str(n_params_c_tr))

        # Share those modules that are to be shared between real and sim models
        if not self.params.get("disable_domain_weight_sharing"):
            print("Sharing weights between sim and real modules")
            model_real.steal_cross_domain_modules(model_sim)
        else:
            print("NOT Sharing weights between sim and real modules")

        self.model_real = model_real
        self.model_sim = model_sim
        self.model_critic = model_critic
        self.model_oracle_critic = model_oracle_critic
        if self.model_oracle_critic:
            print("Using oracle critic")

        # Both optimizers use the same class, learning rate and weight decay.
        if self.optimizer == "adam":
            Optim = optim.Adam
        elif self.optimizer == "sgd":
            Optim = optim.SGD
        else:
            raise ValueError(f"Unsuppored optimizer {self.optimizer}")

        # One optimizer over the parameters of both domain models, a separate
        # one for the critic.
        self.optim_models = Optim(self.model_real.both_domain_parameters(
            self.model_sim),
                                  self.lr,
                                  weight_decay=self.weight_decay)
        self.optim_critic = Optim(self.critic_parameters(),
                                  self.lr,
                                  weight_decay=self.weight_decay)

        self.train_epoch_num = epoch
        self.train_segment = 0
        self.test_epoch_num = epoch
        self.test_segment = 0
        self.set_state(state)
示例#3
0
文件: models.py 项目: hyzcn/drif
def load_model(model_file_override=None,
               real=False,
               model_name_override=False):
    """Build the model selected by the "Setup" parameters and optionally load weights.

    Args:
        model_file_override: Weights file to load instead of
            ``Setup["model_file"]``.
        real: Accepted for interface compatibility; not used by this function.
        model_name_override: Model name to use instead of ``Setup["model"]``.

    Returns:
        A ``(model, model_loaded)`` tuple. ``model`` is ``None`` when the model
        name matches none of the known names. ``model_loaded`` is ``True`` only
        when a PyTorch model was built and a weights file was loaded into it.

    Note:
        An unrecognised oracle type prints a message and terminates the process
        via ``exit(-1)``.
    """

    setup = P.get_current_parameters()["Setup"]
    model_name = model_name_override or setup["model"]
    model_file = model_file_override or setup["model_file"] or None
    cuda = setup["cuda"]
    run_name = setup["run_name"]

    model = None
    pytorch_model = False

    # -----------------------------------------------------------------------------------------------------------------
    # Oracles / baselines that ignore images
    # -----------------------------------------------------------------------------------------------------------------

    if model_name == "oracle":
        rollout_params = get_current_parameters()["Rollout"]
        oracle_type = rollout_params["oracle_type"]
        if oracle_type == "SimpleCarrotPlanner":
            model = SimpleCarrotPlanner()
            print("Using simple carrot planner")
        elif oracle_type == "BasicCarrotPlanner":
            model = BasicCarrotPlanner()
            print("Using basic carrot planner")
        elif oracle_type == "FancyCarrotPlanner":
            model = FancyCarrotPlanner()
            print("Using fancy carrot planner")
        else:
            # Fixed: this branch used to read rollout_params["OracleType"],
            # a key never used elsewhere, so it raised KeyError instead of
            # reporting the unknown oracle.
            print("UNKNOWN ORACLE: ", oracle_type)
            exit(-1)
    elif model_name == "baseline_straight":
        model = BaselineStraight()
    elif model_name == "baseline_stop":
        model = BaselineStop()

    # -----------------------------------------------------------------------------------------------------------------
    # FASTER RSS 2018 Resubmission Model
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "gsmn":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True)
        pytorch_model = True
    elif model_name == "gsmn_wo_jlang":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=False)
        pytorch_model = True
    elif model_name == "gsmn_wo_jgnd":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_map=True,
                         aux_grounding_map=False,
                         aux_goal_map=True,
                         aux_lang=True)
        pytorch_model = True
    elif model_name == "gsmn_wo_jclass":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_map=False,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True)
        pytorch_model = True
    elif model_name == "gsmn_wo_jgoal":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=False,
                         aux_lang=True)
        pytorch_model = True

    elif model_name == "gsmn_w_posnoise":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True,
                         pos_noise=True,
                         rot_noise=False)
        pytorch_model = True
    elif model_name == "gsmn_w_rotnoise":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True,
                         pos_noise=False,
                         rot_noise=True)
        pytorch_model = True
    elif model_name == "gsmn_w_bothnoise":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True,
                         pos_noise=True,
                         rot_noise=True)
        pytorch_model = True

    elif model_name == "gs_fpv":
        model = ModelGSFPV(run_name,
                           aux_class_features=True,
                           aux_grounding_features=True,
                           aux_lang=True,
                           recurrence=False)
        pytorch_model = True
    elif model_name == "gs_fpv_mem":
        model = ModelGSFPV(run_name,
                           aux_class_features=True,
                           aux_grounding_features=True,
                           aux_lang=True,
                           recurrence=True)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # CoRL 2018 Model
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "sm_traj_nav_ratio":
        model = ModelTrajectoryProbRatio(run_name, model_class=mtpr.MODEL_FPV)
        pytorch_model = True
    elif model_name == "sm_traj_nav_ratio_path":
        model = ModelTrajectoryProbRatio(run_name,
                                         model_class=mtpr.PVN_STAGE1_ONLY)
        pytorch_model = True

    elif model_name == "action_gtr":
        model = ModelTrajectoryToAction(run_name)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # CoRL 2018 Refactored
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "pvn_full":
        model = ModelTrajectoryProbRatio(run_name, model_class=mtpr.MODEL_FPV)
        pytorch_model = True
    elif model_name == "pvn_stage1":
        model = ModelTrajectoryProbRatio(run_name,
                                         model_class=mtpr.PVN_STAGE1_ONLY)
        pytorch_model = True
    elif model_name == "pvn_stage2":
        model = ModelTrajectoryToAction(run_name)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # CoRL 2018 Top-Down Full Observability Models
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "top_down_goal_batched":
        model = ModelTopDownPathGoalPredictorBatched(run_name)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # CoRL Baselines
    # -----------------------------------------------------------------------------------------------------------------
    elif model_name == "chaplot":
        model = ModelChaplot(run_name)
        pytorch_model = True

    elif model_name == "misra2017":
        model = ModelMisra2017(run_name)
        pytorch_model = True

    # Only PyTorch models are initialised, have weights loaded and are moved
    # to the GPU; oracles and baselines are returned as constructed.
    model_loaded = False
    if pytorch_model:
        n_params = get_n_params(model)
        n_params_tr = get_n_trainable_params(model)
        print("Loaded PyTorch model!")
        print("Number of model parameters: " + str(n_params))
        print("Trainable model parameters: " + str(n_params_tr))
        model.init_weights()
        model.eval()
        if model_file:
            load_pytorch_model(model, model_file)
            print("Loaded previous model: ", model_file)
            model_loaded = True
        if cuda:
            model = model.cuda()

    return model, model_loaded
示例#4
0
    def __init__(self,
                 run_name,
                 ignore_lang=False,
                 class_loss=True,
                 ground_loss=True):
        """Assemble the sub-networks, auxiliary heads and loss functions.

        Args:
            run_name: Run identifier; also names the SummaryWriter log
                directory ("runs/" followed by the run name).
            ignore_lang: Stored as ``self.ignore_lang``.
            class_loss: When truthy, a linear head for the class auxiliary
                objective is created and registered for weight saving.
            ground_loss: When truthy, the language filter and a linear head for
                the grounding auxiliary objective are created and registered
                for weight saving.
        """
        super(ModelTopDownPathGoalPredictor, self).__init__()
        self.run_name = run_name
        self.model_name = "top_down_path_pred_pretrain"
        log_dir = "runs/" + run_name
        self.writer = SummaryWriter(log_dir=log_dir)

        self.ignore_lang = ignore_lang
        self.class_loss = class_loss
        self.ground_loss = ground_loss

        # feature_net turns the input image into a 2D feature map, and
        # label_pool shrinks the ground-truth labels (given at input-image
        # resolution) so they line up with the predicted labels, which have
        # the size of the feature map.
        self.feature_net = ResNet13Light(32, down_pad=True)
        self.label_pool = nn.MaxPool2d(8)

        if self.ground_loss:
            self.lang_filter = MapLangSemanticFilter(
                sentence_embedding_size, 32, 3)
            self.aux_ground_linear = nn.Linear(3, 2)
            enable_weight_saving(self.lang_filter, "ground_filter")
            enable_weight_saving(self.aux_ground_linear, "ground_aux_linear")

        if RESNET:
            self.unet = ResNetConditional(sentence_embedding_size, 35, 2)
        else:
            # Presumably 32 feature channels plus 3 from the language filter
            # when grounding is enabled -- TODO confirm.
            if self.ground_loss:
                in_channels = 35
            else:
                in_channels = 32
            # hc1 / hb1 are the same whether or not ground_loss is set.
            self.unet = Unet5ContextualBneck(
                in_channels,
                2,
                sentence_embedding_size,
                hc1=48,
                hb1=24,
                hc2=128,
                split_embedding=splitemb,
            )

        if not attention:
            self.sentence_embedding = SentenceEmbeddingSimple(
                word_embedding_size,
                sentence_embedding_size,
                sentence_embedding_layers,
            )
        else:
            self.sentence_embedding = SentenceEmbeddingSelfAttention(
                word_embedding_size,
                lstm_size,
                sentence_embedding_layers,
                attention_heads=attention_heads,
            )

        self.gather2d = Gather2D()

        if self.class_loss:
            self.aux_class_linear = nn.Linear(32, 64)
            enable_weight_saving(self.aux_class_linear, "class_aux_linear")

        n_embedding_params = get_n_params(self.sentence_embedding)
        print("Sentence Embedding #Params: ", n_embedding_params)
        n_unet_params = get_n_params(self.unet)
        print("U-Net #Params: ", n_unet_params)
        print("Class auxiliary: ", self.class_loss)
        print("Ground auxiliary: ", self.ground_loss)

        # Register the main sub-networks so their pre-trained weights can be
        # saved.
        enable_weight_saving(self.feature_net, "feature_resnet_light")
        enable_weight_saving(self.unet, "unet")
        enable_weight_saving(self.sentence_embedding, "sentence_embedding")

        # The mask loss is chosen by the module-level flags, checked in the
        # order NLL, BCE, CE, with MSE as the fallback.
        if NLL:
            self.mask_loss = nn.NLLLoss2d()
        elif BCE:
            self.mask_loss = nn.BCEWithLogitsLoss()
        elif CE:
            self.spatialsoftmax = SpatialSoftmax2d()
            self.mask_loss = CrossEntropy2d()
        else:
            self.mask_loss = nn.MSELoss()

        self.aux_loss = nn.CrossEntropyLoss(reduce=True, size_average=True)
        self.epoch_numbers = {"train": 0, "eval": 0}
        # Held as a non-trainable parameter.
        self.iter = nn.Parameter(torch.zeros(1), requires_grad=False)

        self.dropout = nn.Dropout(0.5)
        self.dropout2d = nn.Dropout2d(0.5)
        self.dropout3d = nn.Dropout3d(0.5)

        self.viz_images, self.instructions = [], []
示例#5
0
def load_model(model_name_override=False,
               model_file_override=None,
               domain="sim"):
    """Build the model selected by the "Setup" parameters and optionally load weights.

    Args:
        model_name_override: Model name to use instead of ``Setup["model"]``.
        model_file_override: Weights file to load instead of
            ``Setup["model_file"]``.
        domain: Passed to the bi-domain models, either as ``domain`` or as
            ``model_instance_name`` depending on the model class.

    Returns:
        A ``(model, model_loaded)`` tuple. ``model`` is ``None`` when the model
        name matches none of the known names. ``model_loaded`` is ``True`` only
        when a PyTorch model was built and a weights file was loaded into it.

    Note:
        An unrecognised oracle type prints a message and terminates the process
        via ``exit(-1)``.
    """

    setup = P.get_current_parameters()["Setup"]
    model_name = model_name_override or setup["model"]
    model_file = model_file_override or setup["model_file"] or None
    # TODO: Move this stuff elsewhere and tidy up the model
    # NOTE(review): the two perception_* values are read but not used anywhere
    # in this function.
    perception_model_file = setup.get("perception_model_file") or None
    perception_model_real = setup.get("perception_model_real") or None
    cuda = setup["cuda"]
    run_name = setup["run_name"]

    model = None
    pytorch_model = False

    # -----------------------------------------------------------------------------------------------------------------
    # Oracles / baselines that ignore images
    # -----------------------------------------------------------------------------------------------------------------

    if model_name == "oracle":
        rollout_params = get_current_parameters()["Rollout"]
        oracle_type = rollout_params["oracle_type"]
        if oracle_type == "SimpleCarrotPlanner":
            model = SimpleCarrotPlanner()
            print("Using simple carrot planner")
        elif oracle_type == "BasicCarrotPlanner":
            model = BasicCarrotPlanner()
            print("Using basic carrot planner")
        elif oracle_type == "FancyCarrotPlanner":
            model = FancyCarrotPlanner()
            print("Using fancy carrot planner")
        else:
            # Fixed: this branch used to read rollout_params["OracleType"],
            # a key never used elsewhere, so it raised KeyError instead of
            # reporting the unknown oracle.
            print("UNKNOWN ORACLE: ", oracle_type)
            exit(-1)
    elif model_name == "average":
        model = BaselineAverage()
    elif model_name == "stop":
        model = BaselineStop()

    # -----------------------------------------------------------------------------------------------------------------
    # FASTER RSS 2018 Resubmission Model
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "gsmn":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True)
        pytorch_model = True
    elif model_name == "gsmn_wo_jlang":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=False)
        pytorch_model = True
    elif model_name == "gsmn_wo_jgnd":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_map=True,
                         aux_grounding_map=False,
                         aux_goal_map=True,
                         aux_lang=True)
        pytorch_model = True
    elif model_name == "gsmn_wo_jclass":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_map=False,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True)
        pytorch_model = True
    elif model_name == "gsmn_wo_jgoal":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=False,
                         aux_lang=True)
        pytorch_model = True

    elif model_name == "gsmn_w_posnoise":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True,
                         pos_noise=True,
                         rot_noise=False)
        pytorch_model = True
    elif model_name == "gsmn_w_rotnoise":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True,
                         pos_noise=False,
                         rot_noise=True)
        pytorch_model = True
    elif model_name == "gsmn_w_bothnoise":
        model = ModelRSS(run_name,
                         model_class=ModelRSS.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=True,
                         pos_noise=True,
                         rot_noise=True)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # RSS Baselines
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "gs_fpv":
        model = ModelGSFPV(run_name,
                           aux_class_features=True,
                           aux_grounding_features=True,
                           aux_lang=True,
                           recurrence=False)
        pytorch_model = True
    elif model_name == "gs_fpv_mem":
        model = ModelGSFPV(run_name,
                           aux_class_features=True,
                           aux_grounding_features=True,
                           aux_lang=True,
                           recurrence=True)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # RSS Model for Cage
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "gsmn_cage":
        model = ModelRSS(run_name,
                         model_class=msrg.MODEL_RSS,
                         aux_class_features=False,
                         aux_grounding_features=False,
                         aux_class_map=True,
                         aux_grounding_map=True,
                         aux_goal_map=True,
                         aux_lang=False)
        pytorch_model = True

    elif model_name == "gsmn_bidomain":
        model = ModelGSMNBiDomain(run_name, model_instance_name=domain)
        pytorch_model = True

    elif model_name == "gsmn_critic":
        model = ModelGsmnCritic(run_name)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # CoRL 2018 Model
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "sm_traj_nav_ratio":
        model = ModelTrajectoryProbRatio(run_name, model_class=mtpr.MODEL_FPV)
        pytorch_model = True
    elif model_name == "sm_traj_nav_ratio_path":
        model = ModelTrajectoryProbRatio(run_name,
                                         model_class=mtpr.PVN_STAGE1_ONLY)
        pytorch_model = True

    elif model_name == "action_gtr":
        model = ModelTrajectoryToAction(run_name)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # CoRL 2018 Refactored
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "pvn_full":
        model = ModelTrajectoryProbRatio(run_name, model_class=mtpr.MODEL_FPV)
        pytorch_model = True
    elif model_name == "pvn_stage1":
        model = ModelTrajectoryProbRatio(run_name,
                                         model_class=mtpr.PVN_STAGE1_ONLY)
        pytorch_model = True
    elif model_name == "pvn_stage2":
        model = ModelTrajectoryToAction(run_name)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # CoRL 2018 Top-Down Full Observability Models
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "top_down_goal_batched":
        model = ModelTopDownPathGoalPredictorBatched(run_name)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------
    # CoRL Model for cage (bidomain)
    # -----------------------------------------------------------------------------------------------------------------

    elif model_name == "pvn_original_stage1_bidomain":
        model = PVN_Stage1_Bidomain_Original(run_name, domain=domain)
        pytorch_model = True

    elif model_name == "pvn_stage1_bidomain":
        model = PVN_Stage1_Bidomain(run_name, domain=domain)
        pytorch_model = True

    elif model_name == "pvn_stage2_bidomain":
        model = PVN_Stage2_Bidomain(run_name, model_instance_name=domain)
        pytorch_model = True

    elif model_name == "pvn_stage2_actor_critic":
        model = PVN_Stage2_ActorCritic(run_name, model_instance_name=domain)
        pytorch_model = True

    elif model_name == "pvn_stage1_critic":
        model = PVN_Stage1_Critic(run_name)
        pytorch_model = True

    elif model_name == "pvn_stage1_critic_big":
        model = PVN_Stage1_Critic_Big(run_name)
        pytorch_model = True

    elif model_name == "pvn_full_bidomain":
        model = PVN_Wrapper_Bidomain(run_name,
                                     model_instance_name=domain,
                                     oracle_stage1=False)
        pytorch_model = True

    elif model_name == "pvn_full_bidomain_ground_truth":
        model = PVN_Wrapper_Bidomain(run_name,
                                     model_instance_name=domain,
                                     oracle_stage1=True)
        pytorch_model = True

    # -----------------------------------------------------------------------------------------------------------------

    # Only PyTorch models are initialised, have weights loaded and are moved
    # to the GPU; oracles and baselines are returned as constructed.
    model_loaded = False
    if pytorch_model:
        n_params = get_n_params(model)
        n_params_tr = get_n_trainable_params(model)
        print("Loaded PyTorch model!")
        print("Number of model parameters: " + str(n_params))
        print("Trainable model parameters: " + str(n_params_tr))
        model.init_weights()
        model.eval()
        if model_file:
            load_pytorch_model(model, model_file, pytorch3to4=True)
            print("Loaded previous model: ", model_file)
            model_loaded = True

        if cuda:
            model = model.cuda()

    return model, model_loaded