예제 #1
0
    def train(self, dsId=None):
        """Train ``self.net`` with the engine selected by ``self.args["predAlgo"]``.

        Puts the network in training mode, optionally moves it to the
        configured CUDA device, builds the optimizer named by
        ``self.args["optim"]``, wraps the train and validation datasets in
        ``DataLoader`` objects and calls ``start()`` on the engine.

        Args:
            dsId: Optional dataset id. Any value other than ``None`` is stored
                in ``self.args['dsId']`` before the datasets are built.

        Returns:
            None. When ``self.args["optim"]`` is not a supported optimizer
            name, "undefined optimizer." is printed and the method returns
            before any data is loaded.
        """
        # Compare against None (not truthiness) so a dataset id of 0 is honoured.
        if dsId is not None:
            self.args['dsId'] = dsId

        self.net.train_flag = True
        self.net.train()
        if self.args["cuda"]:
            self.net.cuda(self.args['device'])

        # Every supported name is also the name of the class in torch.optim,
        # so a lookup replaces the long if/elif chain.
        supported_optims = ("Adam", "SGD", "AdamW", "SparseAdam", "Adamax",
                            "ASGD", "Rprop", "RMSprop", "LBFGS")
        optim_name = self.args["optim"]
        if optim_name not in supported_optims:
            print("undefined optimizer.")
            return
        optim = getattr(torch.optim, optim_name)(
            self.net.parameters(), lr=self.args['learning_rate'])

        print('loading data in {}...'.format(self.args['dsId']))
        trSet_path = os.path.join(self.args["dir"], "trainSet")
        valSet_path = os.path.join(self.args["dir"], "valSet")
        trSet = ngsimDataset(trSet_path, self.args["dir"], self.args["raw_dir"],
                             'train', self.args['dsId'],
                             t_h=self.args['in_length'],
                             t_f=self.args['out_length'])
        valSet = ngsimDataset(valSet_path, self.args["dir"], self.args["raw_dir"],
                              'val', self.args['dsId'],
                              t_h=self.args['in_length'],
                              t_f=self.args['out_length'])

        trDataloader = DataLoader(trSet,
                                  batch_size=self.args['batch_size'],
                                  shuffle=True,
                                  num_workers=8,
                                  collate_fn=trSet.collate_fn)
        valDataloader = DataLoader(valSet,
                                   batch_size=self.args['batch_size'],
                                   shuffle=True,
                                   num_workers=8,
                                   collate_fn=valSet.collate_fn)

        print('start training {}...'.format(self.args["predAlgo"]))
        # Anything other than "Traphic" falls back to the social engine.
        if self.args["predAlgo"] == "Traphic":
            engine_cls = TraphicEngine
        else:
            engine_cls = SocialEngine
        engine = engine_cls(self.net, optim, trDataloader, valDataloader,
                            self.args)

        engine.start()
예제 #2
0
    def result_viz(self):  # TODO: unfinished, see the review notes below
        """Visualize the network output against the ground-truth trajectory.

        Work in progress: the method currently only builds a dataset, a
        ``DataLoader`` and an engine. It does not call anything on the engine
        and returns ``None``.

        NOTE(review): ``valSet_path``, ``optim``, ``trDataloader`` and
        ``valDataloader`` are read below but never assigned in this method.
        Unless they exist at module scope, the first use raises ``NameError``.
        """
        # Built here but not used afterwards; the dataset below reads
        # ``valSet_path`` instead. NOTE(review): confirm which path is intended.
        tstSet_path = os.path.join(self.args["dir"], "testSet")

        testSet = ngsimDataset(valSet_path,
                               self.args["dir"],
                               self.args["raw_dir"],
                               'val',
                               self.args['dsId'],
                               t_h=self.args['in_length'],
                               t_f=self.args['out_length'])
        testDataloader = DataLoader(testSet,
                                    batch_size=self.args['batch_size'],
                                    shuffle=True,
                                    num_workers=4,
                                    collate_fn=testSet.collate_fn)
        # Anything other than "Traphic" falls back to the social engine.
        if self.args["predAlgo"] == "Traphic":
            engine = TraphicEngine(self.net, optim, trDataloader,
                                   valDataloader, self.args)
        else:
            engine = SocialEngine(self.net, optim, trDataloader, valDataloader,
                                  self.args)
예제 #3
0
    def evaluate(self, dsId=None):
        """Evaluate ``self.net`` with the engine selected by ``self.args["predAlgo"]``.

        Puts the network in eval mode, loads saved weights from
        ``os.path.join(self.args['modelLoc'], self.args['name'])`` when that
        path exists (otherwise the network is used as it is), builds the
        optimizer and data loaders the engine constructor takes, and calls
        ``engine.eval`` on the evaluation loader.

        Args:
            dsId: Optional dataset id. Any value other than ``None`` is stored
                in ``self.args['dsId']`` before the datasets are built.

        Returns:
            None. When ``self.args["optim"]`` is not a supported optimizer
            name, "undefined optimizer." is printed and the method returns
            before any data is loaded.
        """
        # Compare against None (not truthiness) so a dataset id of 0 is honoured.
        if dsId is not None:
            self.args['dsId'] = dsId

        self.net.train_flag = False
        self.net.eval()

        d = os.path.join(self.args['modelLoc'], self.args['name'])
        if os.path.exists(d):
            self.net.load_state_dict(torch.load(d))
            print("\n[INFO]: model {} loaded".format(d))
        else:
            print("\n[INFO]: can not find model at {} to evaluate, using existing net".format(d))

        if self.args["cuda"]:
            self.net.cuda(self.args['device'])

        # Every supported name is also the name of the class in torch.optim,
        # so a lookup replaces the long if/elif chain.
        supported_optims = ("Adam", "SGD", "AdamW", "SparseAdam", "Adamax",
                            "ASGD", "Rprop", "RMSprop", "LBFGS")
        optim_name = self.args["optim"]
        if optim_name not in supported_optims:
            print("undefined optimizer.")
            return
        optim = getattr(torch.optim, optim_name)(
            self.net.parameters(), lr=self.args['learning_rate'])

        print('loading data in {}...'.format(self.args['dsId']))
        trSet_path = os.path.join(self.args["dir"], "trainSet")
        valSet_path = os.path.join(self.args["dir"], "valSet")
        tstSet_path = os.path.join(self.args["dir"], "testSet")

        trSet = ngsimDataset(trSet_path, self.args["dir"], self.args["raw_dir"],
                             'train', self.args['dsId'],
                             t_h=self.args['in_length'],
                             t_f=self.args['out_length'])
        trDataloader = DataLoader(trSet,
                                  batch_size=self.args['batch_size'],
                                  shuffle=True,
                                  num_workers=8,
                                  collate_fn=trSet.collate_fn)

        # NOTE(review): the loader named "test" reads the "valSet" directory and
        # the loader named "val" reads the "testSet" directory. That pairing is
        # kept exactly as it was; confirm whether it is intentional.
        testSet = ngsimDataset(valSet_path, self.args["dir"], self.args["raw_dir"],
                               'val', self.args['dsId'],
                               t_h=self.args['in_length'],
                               t_f=self.args['out_length'])
        testDataloader = DataLoader(testSet,
                                    batch_size=self.args['batch_size'],
                                    shuffle=True,
                                    num_workers=8,
                                    collate_fn=testSet.collate_fn)

        valSet = ngsimDataset(tstSet_path, self.args["dir"], self.args["raw_dir"],
                              'val', self.args['dsId'],
                              t_h=self.args['in_length'],
                              t_f=self.args['out_length'])
        valDataloader = DataLoader(valSet,
                                   batch_size=self.args['batch_size'],
                                   shuffle=True,
                                   num_workers=8,
                                   collate_fn=valSet.collate_fn)

        print('start testing {}...'.format(self.args["predAlgo"]))
        # Anything other than "Traphic" falls back to the social engine.
        if self.args["predAlgo"] == "Traphic":
            engine_cls = TraphicEngine
        else:
            engine_cls = SocialEngine
        engine = engine_cls(self.net, optim, trDataloader, valDataloader,
                            self.args)

        engine.eval(testDataloader)
예제 #4
0
    def evaluate(self, viewArgs, thread=None):
        """Evaluate a saved prediction model chosen through ``viewArgs``.

        Args:
            viewArgs: Mapping read for the keys ``"predAlgo"``, ``"optim"`` and
                ``"dir"``; it is also passed to ``self.getPredArgs``.
            thread: Optional object with a ``signalCanvas(str)`` method used
                to report progress. When falsy, no progress is reported.

        Returns:
            None. The method returns early, without evaluating, when
            ``predAlgo`` is "Social GAN" (delegated to ``sganEval``), when
            ``predAlgo`` is not an implemented algorithm, or when the
            optimizer selection is anything other than "Adam".
        """
        def notify(message):
            # Progress reporting is optional; forward only when a thread exists.
            if thread:
                thread.signalCanvas(message)

        notify("\n[INFO] Evaluation started...")
        args = self.getPredArgs(viewArgs)
        args['eval'] = True
        predAlgo = viewArgs["predAlgo"]
        optimSelection = viewArgs["optim"]

        if predAlgo == "Traphic":
            notify("\n[INFO]: Using Traphic for the saved model")
            args['train_flag'] = False
            args["ours"] = True
            net = traphicNet(args)
        elif predAlgo == "Social GAN":
            sganEval(sgan_args, thread)
            return
        elif predAlgo == "Social Conv":
            notify("\n[INFO]: Using Convolutional Social Pooling")
            args['train_flag'] = False
            net = highwayNet(args)
        else:
            # "Social-LSTM" and unknown algorithms have no network or engine
            # here. Stop now instead of failing later on an unbound ``net``.
            print(predAlgo)
            notify("\n[INFO]: NOT YET IMPLEMENTED")
            return

        net.eval()
        d = os.path.join(args['modelLoc'])
        notify(d)

        if os.path.exists(d):
            net.load_state_dict(torch.load(d))
            notify("\n[INFO]: model loaded")
        else:
            # Evaluation still proceeds with the freshly constructed network.
            notify("\n[INFO]: can not find model to evaluate")

        if args["cuda"]:
            notify("\n[INFO]: Using CUDA")
            net.cuda()

        if optimSelection != "Adam":
            notify("\n[INFO]: NOT YET IMPLEMENTED")
            return
        optim = torch.optim.Adam(net.parameters(), lr=args['learning_rate'])
        notify("\n[INFO]: Optimizer: \n" + str(optim))

        # Only constructed so its description can be reported below.
        crossEnt = torch.nn.BCELoss()
        notify("\n[INFO]: Loss: \n" + str(crossEnt))

        # TODO: hard-coded layout; the dataset name is taken from the third
        # '/'-separated component of viewArgs["dir"].
        dataset_name = viewArgs["dir"].split('/')[2]
        prediction_data_path = 'model/Prediction/data/{}'.format(dataset_name)
        trSet_path = os.path.join(prediction_data_path, "TrainSet.npy")
        valSet_path = os.path.join(prediction_data_path, "ValSet.npy")
        tstSet_path = os.path.join(prediction_data_path, "TestSet.npy")

        trSet = ngsimDataset(trSet_path)
        trDataloader = DataLoader(trSet,
                                  batch_size=args['batch_size'],
                                  shuffle=True,
                                  num_workers=8,
                                  collate_fn=trSet.collate_fn)

        # NOTE(review): the loader named "test" reads ValSet.npy and is the one
        # evaluated. The loader named "val" reads TestSet.npy and is never
        # used. Kept exactly as it was; confirm which split is intended.
        testSet = ngsimDataset(valSet_path)
        testDataloader = DataLoader(testSet,
                                    batch_size=args['batch_size'],
                                    shuffle=True,
                                    num_workers=8,
                                    collate_fn=testSet.collate_fn)

        valSet = ngsimDataset(tstSet_path)
        valDataloader = DataLoader(valSet,
                                   batch_size=args['batch_size'],
                                   shuffle=True,
                                   num_workers=8,
                                   collate_fn=valSet.collate_fn)

        # Only "Traphic" and "Social Conv" reach this point.
        if predAlgo == "Traphic":
            args["ours"] = True
            engine = TraphicEngine(net, optim, trDataloader, testDataloader,
                                   args, thread)
        else:
            engine = SocialEngine(net, optim, trDataloader, testDataloader,
                                  args, thread)
        notify("\n[INFO]: *** Evaluating Prediction Model ***")

        engine.eval(testDataloader)
예제 #5
0
    def train(self, viewArgs, thread=None):
        """Train the prediction model chosen through ``viewArgs``.

        Args:
            viewArgs: Mapping read for the keys ``"predAlgo"``, ``"optim"`` and
                ``"dir"``; it is also passed to ``self.getPredArgs``.
            thread: Optional object with a ``signalCanvas(str)`` method used
                to report progress. When falsy, no progress is reported.

        Returns:
            None. The method returns early, without starting an engine, when
            ``predAlgo`` is "Social GAN" (delegated to ``sganTrain``), when
            ``predAlgo`` is not an implemented algorithm, or when the
            optimizer selection is anything other than "Adam".
        """
        def notify(message):
            # Progress reporting is optional; forward only when a thread exists.
            if thread:
                thread.signalCanvas(message)

        notify("\n[INFO] Training started...")

        args = self.getPredArgs(viewArgs)
        args['eval'] = False
        predAlgo = viewArgs["predAlgo"]
        optimSelection = viewArgs["optim"]

        if predAlgo == "Traphic":
            notify("\n[INFO]: Using TRAPHIC model")
            args["ours"] = True
            args['train_flag'] = True
            net = traphicNet(args)
        elif predAlgo == "Social GAN":
            sgan_args.num_epochs = int(args["pretrainEpochs"]) + int(
                args["trainEpochs"])
            sgan_args.batch_size = args["batch_size"]
            sgan_args.dropout = args['dropout_prob']
            # NOTE(review): the original assigned g_learning_rate twice in a
            # row; presumably the second was meant to be d_learning_rate.
            # Confirm against the SGAN arguments.
            sgan_args.g_learning_rate = args["learning_rate"]
            notify("\n[INFO]: Using Sgan model")
            notify("\n[INFO]: *** Training Prediction Model ***")
            sganTrain(sgan_args, thread)
            return
        elif predAlgo == "Social Conv":
            notify("\n[INFO]: Using Convolutional Social Pooling")
            args['train_flag'] = True
            net = highwayNet(args)
        else:
            # "Social-LSTM" and unknown algorithms have no network or engine
            # here. Stop now instead of failing later on an unbound ``net``.
            print(predAlgo)
            notify("\n[INFO]: NOT YET IMPLEMENTED")
            return

        if args["cuda"]:
            notify("\n[INFO]: Using CUDA")
            net.cuda()

        if optimSelection != "Adam":
            notify("\n[INFO]: NOT YET IMPLEMENTED")
            return
        optim = torch.optim.Adam(net.parameters(), lr=args['learning_rate'])
        notify("\n[INFO]: Optimizer: \n" + str(optim))

        # Only constructed so its description can be reported below.
        crossEnt = torch.nn.BCELoss()
        notify("\n[INFO]: Loss: \n" + str(crossEnt))

        # Hard-coded layout: the dataset name is taken from the third
        # '/'-separated component of viewArgs["dir"].
        dataset_name = viewArgs["dir"].split('/')[2]
        prediction_data_path = 'model/Prediction/data/{}'.format(dataset_name)
        trSet_path = os.path.join(prediction_data_path, "TrainSet.npy")
        valSet_path = os.path.join(prediction_data_path, "ValSet.npy")
        trSet = ngsimDataset(trSet_path)
        valSet = ngsimDataset(valSet_path)

        trDataloader = DataLoader(trSet,
                                  batch_size=args['batch_size'],
                                  shuffle=True,
                                  num_workers=8,
                                  collate_fn=trSet.collate_fn)
        valDataloader = DataLoader(valSet,
                                   batch_size=args['batch_size'],
                                   shuffle=True,
                                   num_workers=8,
                                   collate_fn=valSet.collate_fn)

        # Only "Traphic" and "Social Conv" reach this point.
        if predAlgo == "Traphic":
            args["ours"] = True
            engine = TraphicEngine(net, optim, trDataloader, valDataloader,
                                   args, thread)
        else:
            engine = SocialEngine(net, optim, trDataloader, valDataloader,
                                  args, thread)
        notify("\n[INFO]: *** Training Prediction Model ***")
        engine.start()