Ejemplo n.º 1
0
    def __init__(self, inArgs):
        """Collect model/training settings from ``inArgs`` and build the network.

        Args:
            inArgs: mapping of user-supplied settings. Every key read below
                must be present (a missing key raises ``KeyError``).

        Side effects:
            Seeds torch's CPU and CUDA RNGs with ``inArgs['dsId']``; sets
            ``self.args`` (the settings dict) and ``self.net`` (a
            ``traphicNet`` when predAlgo is "Traphic", else a ``highwayNet``),
            moving the net to ``self.args['device']`` when CUDA is enabled.
        """
        # The dataset id doubles as the RNG seed.
        seed = inArgs['dsId']
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.args = {
            # Runtime / schedule settings.
            "batch_size": inArgs["batch_size"],
            "pretrainEpochs": inArgs["pretrainEpochs"],
            "trainEpochs": inArgs["trainEpochs"],
            "cuda": inArgs["cuda"],
            "device": inArgs["device"],
            "modelLoc": inArgs["modelLoc"],  # where the trained model lives
            "optim": inArgs["optim"],
            # Network arguments; the sizes below are fixed constants.
            "dropout_prob": inArgs["dropout"],
            "encoder_size": 64,
            "decoder_size": 128,
            "in_length": inArgs["input_size"],    # input (history) length
            "out_length": inArgs["output_size"],  # output sequence length
            "grid_size": (13, 3),
            "upp_grid_size": (7, 3),
            "soc_conv_depth": 64,
            "conv_3x1_depth": 16,
            "dyn_embedding_size": 32,
            "input_embedding_size": 32,
            "num_lat_classes": 3,
            "num_lon_classes": 2,
            "use_maneuvers": inArgs["maneuvers"],
            # True exactly when the TraPHic model is selected.
            "ours": inArgs["predAlgo"] == "Traphic",
            "nll_only": inArgs["nll_only"],
            "learning_rate": inArgs["lr"],
            "predAlgo": inArgs["predAlgo"],
            "w_decay": inArgs["w_decay"],
            # Model file name built from a caller-supplied template.
            "name": inArgs["name_temp"].format(inArgs["predAlgo"],
                                               inArgs["dset"]),
            "pretrain_loss": inArgs["pretrain_loss"],
            "train_loss": inArgs["train_loss"],
            "dir": inArgs["dir"],
            "raw_dir": inArgs["raw_dir"],
            "dsId": inArgs["dsId"],
            "log_dir": inArgs["log_dir"],
            "tensorboard": inArgs["tensorboard"],
            "class_type": inArgs["class_type"],
        }

        # Pick the network class from the selected prediction algorithm.
        network_cls = (traphicNet if self.args["predAlgo"] == "Traphic"
                       else highwayNet)
        self.net = network_cls(self.args)

        if self.args["cuda"]:
            self.net = self.net.cuda(self.args["device"])
    def __init__(self, inArgs):
        """Collect model/training settings from ``inArgs`` and build the network.

        Unlike a seeded variant, this constructor does not touch the RNGs,
        always sets ``nll_only`` to True and ``dsId`` to 0, and prints the
        constructed network.

        Args:
            inArgs: mapping of user-supplied settings. Every key read below
                must be present (a missing key raises ``KeyError``).

        Side effects:
            Sets ``self.args`` and ``self.net`` (a ``traphicNet`` when
            predAlgo is "Traphic", else a ``highwayNet``), moves the net to
            ``self.args['device']`` when CUDA is enabled, and prints it.
        """
        chosen_algo = inArgs["predAlgo"]
        self.args = dict(
            batch_size=inArgs["batch_size"],
            pretrainEpochs=inArgs["pretrainEpochs"],
            trainEpochs=inArgs["trainEpochs"],
            cuda=inArgs["cuda"],
            device=inArgs["device"],
            modelLoc=inArgs["modelLoc"],
            optim=inArgs["optim"],
            # Network arguments; the sizes below are fixed constants.
            dropout_prob=inArgs["dropout"],
            encoder_size=64,
            decoder_size=128,
            in_length=inArgs["input_size"],
            out_length=inArgs["output_size"],
            grid_size=(13, 3),
            upp_grid_size=(7, 3),
            soc_conv_depth=64,
            conv_3x1_depth=16,
            dyn_embedding_size=32,
            input_embedding_size=32,
            num_lat_classes=3,
            num_lon_classes=2,
            use_maneuvers=inArgs["maneuvers"],
            ours=(chosen_algo == "Traphic"),
            nll_only=True,  # hardcoded in this variant
            learning_rate=inArgs["lr"],
            predAlgo=chosen_algo,
            name=inArgs["name_temp"].format(chosen_algo, inArgs["dset"]),
            pretrain_loss=inArgs["pretrain_loss"],
            train_loss=inArgs["train_loss"],
            dir=inArgs["dir"],
            raw_dir=inArgs["raw_dir"],
            dsId=0,  # hardcoded in this variant
            log_dir=inArgs["log_dir"],
            tensorboard=inArgs["tensorboard"],
        )

        if chosen_algo == "Traphic":
            network = traphicNet(self.args)
        else:
            network = highwayNet(self.args)
        if self.args["cuda"]:
            network = network.cuda(self.args["device"])
        self.net = network
        print(self.net)
Ejemplo n.º 3
0
    def evaluate(self, viewArgs, thread=None):
        """Load a saved prediction model and evaluate it on the test split.

        Args:
            viewArgs: settings dict from the view. Read directly here:
                "predAlgo", "optim" and "dir"; it is also passed to
                ``self.getPredArgs`` to build the model args (presumably a
                dict containing "modelLoc", "cuda", "learning_rate" and
                "batch_size" -- those keys are read below).
            thread: optional object with a ``signalCanvas(str)`` method; when
                given, progress messages are sent to it.

        Returns:
            None. Returns early (after signalling a message) when the
            algorithm or optimizer is not supported, or when no saved model
            exists at ``args['modelLoc']``. "Social GAN" is delegated to
            ``sganEval``, after which this method returns.
        """
        def log(msg):
            # Forward progress messages only when a UI thread was supplied.
            if thread:
                thread.signalCanvas(msg)

        log("\n[INFO] Evaluation started...")
        args = self.getPredArgs(viewArgs)
        args['eval'] = True
        predAlgo = viewArgs["predAlgo"]
        optimSelection = viewArgs["optim"]

        if predAlgo == "Traphic":
            log("\n[INFO]: Using Traphic for the saved model")
            args['train_flag'] = False
            args["ours"] = True
            net = traphicNet(args)
        elif predAlgo == "Social GAN":
            sganEval(sgan_args, thread)
            return
        elif predAlgo == "Social Conv":
            log("\n[INFO]: Using Convolutional Social Pooling")
            args['train_flag'] = False
            net = highwayNet(args)
        else:
            # "Social-LSTM" and unknown algorithms build no network; falling
            # through would crash with NameError on `net`.
            log("\n[INFO]: NOT YET IMPLEMENTED")
            return

        net.eval()
        model_path = os.fspath(args['modelLoc'])
        log(model_path)

        if not os.path.exists(model_path):
            # Evaluating randomly initialised weights would report
            # meaningless metrics, so stop here.
            log("\n[INFO]: can not find model to evaluate")
            return
        # Load onto the CPU so GPU-saved checkpoints also load on CPU-only
        # machines; the net is moved to CUDA below when requested.
        net.load_state_dict(torch.load(model_path, map_location='cpu'))
        log("\n[INFO]: model loaded")

        if args["cuda"]:
            log("\n[INFO]: Using CUDA")
            net.cuda()

        if optimSelection != "Adam":
            log("\n[INFO]: NOT YET IMPLEMENTED")
            return
        optim = torch.optim.Adam(net.parameters(), lr=args['learning_rate'])
        log("\n[INFO]: Optimizer: \n" + str(optim))

        crossEnt = torch.nn.BCELoss()
        log("\n[INFO]: Loss: \n" + str(crossEnt))

        # TODO: More hardcodes -- the dataset name is the third '/'-separated
        # component of viewArgs["dir"].
        dataset_name = viewArgs["dir"].split('/')[2]
        prediction_data_path = 'model/Prediction/data/{}'.format(dataset_name)
        trSet_path = os.path.join(prediction_data_path, "TrainSet.npy")
        tstSet_path = os.path.join(prediction_data_path, "TestSet.npy")

        trSet = ngsimDataset(trSet_path)
        trDataloader = DataLoader(trSet,
                                  batch_size=args['batch_size'],
                                  shuffle=True,
                                  num_workers=8,
                                  collate_fn=trSet.collate_fn)

        # Evaluate on the held-out test split (previously ValSet.npy was
        # loaded into `testSet` and TestSet.npy into an unused `valSet`).
        testSet = ngsimDataset(tstSet_path)
        testDataloader = DataLoader(testSet,
                                    batch_size=args['batch_size'],
                                    shuffle=True,
                                    num_workers=8,
                                    collate_fn=testSet.collate_fn)

        # Only "Traphic" and "Social Conv" can reach this point.
        if predAlgo == "Traphic":
            args["ours"] = True
            engine = TraphicEngine(net, optim, trDataloader, testDataloader,
                                   args, thread)
        else:
            engine = SocialEngine(net, optim, trDataloader, testDataloader,
                                  args, thread)

        log("\n[INFO]: *** Evaluating Prediction Model ***")
        engine.eval(testDataloader)
Ejemplo n.º 4
0
    def train(self, viewArgs, thread=None):
        """Build the selected prediction network and train it.

        Args:
            viewArgs: settings dict from the view. Read directly here:
                "predAlgo", "optim" and "dir"; it is also passed to
                ``self.getPredArgs`` to build the model/training args
                (presumably a dict containing "cuda", "learning_rate",
                "batch_size", epoch counts, etc. -- those keys are read below).
            thread: optional object with a ``signalCanvas(str)`` method; when
                given, progress messages are sent to it.

        Returns:
            None. Returns early, after signalling "NOT YET IMPLEMENTED", when
            the algorithm or optimizer is not supported. "Social GAN" is
            delegated to ``sganTrain``, after which this method returns.
        """
        def log(msg):
            # Forward progress messages only when a UI thread was supplied.
            if thread:
                thread.signalCanvas(msg)

        log("\n[INFO] Training started...")

        args = self.getPredArgs(viewArgs)
        args['eval'] = False
        predAlgo = viewArgs["predAlgo"]
        optimSelection = viewArgs["optim"]

        if predAlgo == "Traphic":
            log("\n[INFO]: Using TRAPHIC model")
            args["ours"] = True
            args['train_flag'] = True
            net = traphicNet(args)
        elif predAlgo == "Social GAN":
            sgan_args.num_epochs = int(args["pretrainEpochs"]) + int(
                args["trainEpochs"])
            sgan_args.batch_size = args["batch_size"]
            sgan_args.dropout = args['dropout_prob']
            # NOTE(review): this was previously assigned twice in a row; one
            # copy may have been meant for a different field -- confirm.
            sgan_args.g_learning_rate = args["learning_rate"]
            log("\n[INFO]: Using Sgan model")
            log("\n[INFO]: *** Training Prediction Model ***")
            sganTrain(sgan_args, thread)
            return
        elif predAlgo == "Social Conv":
            log("\n[INFO]: Using Convolutional Social Pooling")
            args['train_flag'] = True
            net = highwayNet(args)
        else:
            # "Social-LSTM" and unknown algorithms build no network; falling
            # through would crash with NameError on `net`.
            log("\n[INFO]: NOT YET IMPLEMENTED")
            return

        if args["cuda"]:
            log("\n[INFO]: Using CUDA")
            net.cuda()

        if optimSelection != "Adam":
            log("\n[INFO]: NOT YET IMPLEMENTED")
            return
        optim = torch.optim.Adam(net.parameters(), lr=args['learning_rate'])
        log("\n[INFO]: Optimizer: \n" + str(optim))

        crossEnt = torch.nn.BCELoss()
        log("\n[INFO]: Loss: \n" + str(crossEnt))

        # The dataset name is the third '/'-separated component of
        # viewArgs["dir"].
        dataset_name = viewArgs["dir"].split('/')[2]
        prediction_data_path = 'model/Prediction/data/{}'.format(dataset_name)
        trSet_path = os.path.join(prediction_data_path, "TrainSet.npy")
        valSet_path = os.path.join(prediction_data_path, "ValSet.npy")
        trSet = ngsimDataset(trSet_path)
        valSet = ngsimDataset(valSet_path)

        trDataloader = DataLoader(trSet,
                                  batch_size=args['batch_size'],
                                  shuffle=True,
                                  num_workers=8,
                                  collate_fn=trSet.collate_fn)
        valDataloader = DataLoader(valSet,
                                   batch_size=args['batch_size'],
                                   shuffle=True,
                                   num_workers=8,
                                   collate_fn=valSet.collate_fn)

        # Only "Traphic" and "Social Conv" can reach this point.
        if predAlgo == "Traphic":
            args["ours"] = True
            engine = TraphicEngine(net, optim, trDataloader, valDataloader,
                                   args, thread)
        else:
            engine = SocialEngine(net, optim, trDataloader, valDataloader,
                                  args, thread)

        log("\n[INFO]: *** Training Prediction Model ***")
        engine.start()