def train(self, dsId=None):
    """Train ``self.net`` on the train/validation splits of the configured dataset.

    Configuration is read from ``self.args``. Keys used here: ``cuda``,
    ``device``, ``optim``, ``learning_rate``, ``dir``, ``raw_dir``, ``dsId``,
    ``in_length``, ``out_length``, ``batch_size``, ``predAlgo``.
    The method builds an optimizer and the two data loaders. It hands them to
    a ``TraphicEngine`` when ``predAlgo`` is ``"Traphic"`` and to a
    ``SocialEngine`` otherwise, then starts the engine.

    Args:
        dsId: Optional dataset identifier. When truthy it overwrites
            ``self.args['dsId']`` before any data is loaded.

    Returns:
        None. Returns early, after printing ``"undefined optimizer."``, when
        ``self.args['optim']`` is not one of the supported optimizer names.
    """
    if dsId:
        self.args['dsId'] = dsId

    self.net.train_flag = True
    self.net.train()
    if self.args["cuda"]:
        self.net.cuda(self.args['device'])

    # Every supported optimizer is a torch.optim class of the same name and is
    # constructed identically, so look the class up instead of branching per
    # name. The explicit whitelist keeps non-optimizer attributes of
    # torch.optim (e.g. the ``Optimizer`` base class) from being selected.
    supported_optimizers = (
        "Adam", "SGD", "AdamW", "SparseAdam", "Adamax",
        "ASGD", "Rprop", "RMSprop", "LBFGS",
    )
    optim_name = self.args["optim"]
    if optim_name not in supported_optimizers:
        print("undefined optimizer.")
        return
    # NOTE(review): LBFGS.step() requires a closure argument; confirm the
    # engine's training loop provides one before relying on "LBFGS".
    optim = getattr(torch.optim, optim_name)(
        self.net.parameters(), lr=self.args['learning_rate'])

    print('loading data in {}...'.format(self.args['dsId']))
    trSet_path = os.path.join(self.args["dir"], "trainSet")
    valSet_path = os.path.join(self.args["dir"], "valSet")
    # History / future horizon lengths shared by both splits.
    horizon = dict(t_h=self.args['in_length'], t_f=self.args['out_length'])
    trSet = ngsimDataset(trSet_path, self.args["dir"], self.args["raw_dir"],
                         'train', self.args['dsId'], **horizon)
    valSet = ngsimDataset(valSet_path, self.args["dir"], self.args["raw_dir"],
                          'val', self.args['dsId'], **horizon)

    trDataloader = DataLoader(trSet, batch_size=self.args['batch_size'],
                              shuffle=True, num_workers=8,
                              collate_fn=trSet.collate_fn)
    valDataloader = DataLoader(valSet, batch_size=self.args['batch_size'],
                               shuffle=True, num_workers=8,
                               collate_fn=valSet.collate_fn)

    print('start training {}...'.format(self.args["predAlgo"]))
    if self.args["predAlgo"] == "Traphic":
        engine = TraphicEngine(self.net, optim, trDataloader, valDataloader, self.args)
    else:
        engine = SocialEngine(self.net, optim, trDataloader, valDataloader, self.args)
    engine.start()
def train(self, viewArgs, thread=None):
    """Train the prediction model selected in the GUI view.

    Args:
        viewArgs: Mapping of view settings. ``"predAlgo"``, ``"optim"`` and
            ``"dir"`` are read directly. The whole mapping is also passed to
            ``self.getPredArgs`` to obtain the model arguments.
        thread: Optional worker object. When truthy, progress messages are
            sent to it through ``thread.signalCanvas``.

    Supported algorithms:
        - ``"Traphic"``: trained with a ``TraphicEngine``.
        - ``"Social Conv"``: trained with a ``SocialEngine``.
        - ``"Social GAN"``: delegated to ``sganTrain``.

    Any other algorithm (including ``"Social-LSTM"``) or optimizer other than
    ``"Adam"`` reports "NOT YET IMPLEMENTED" and returns without training.

    Returns:
        None.
    """
    def _signal(message):
        # Forward a status message to the GUI canvas, if a thread is attached.
        if thread:
            thread.signalCanvas(message)

    _signal("\n[INFO] Training started...")
    args = self.getPredArgs(viewArgs)
    args['eval'] = False
    predAlgo = viewArgs["predAlgo"]
    optimSelection = viewArgs["optim"]

    if predAlgo == "Traphic":
        _signal("\n[INFO]: Using TRAPHIC model")
        args["ours"] = True
        args['train_flag'] = True
        net = traphicNet(args)
    elif predAlgo == "Social GAN":
        # Social GAN has its own training loop: configure it and delegate.
        sgan_args.num_epochs = int(args["pretrainEpochs"]) + int(args["trainEpochs"])
        sgan_args.batch_size = args["batch_size"]
        sgan_args.dropout = args['dropout_prob']
        # NOTE(review): only the generator learning rate is set here; if
        # sgan_args also carries a discriminator rate, confirm it should not
        # follow args["learning_rate"] as well.
        sgan_args.g_learning_rate = args["learning_rate"]
        _signal("\n[INFO]: Using Sgan model")
        _signal("\n[INFO]: *** Training Prediction Model ***")
        sganTrain(sgan_args, thread)
        return
    elif predAlgo == "Social Conv":
        _signal("\n[INFO]: Using Convolutional Social Pooling")
        args['train_flag'] = True
        net = highwayNet(args)
    else:
        # "Social-LSTM" and unknown algorithms have no model here; stop now
        # instead of continuing with `net` unbound.
        _signal("\n[INFO]: NOT YET IMPLEMENTED")
        return

    if args["cuda"]:
        _signal("\n[INFO]: Using CUDA")
        net.cuda()

    if optimSelection != "Adam":
        _signal("\n[INFO]: NOT YET IMPLEMENTED")
        return
    optim = torch.optim.Adam(net.parameters(), lr=args['learning_rate'])
    _signal("\n[INFO]: Optimizer: \n" + str(optim))

    # The loss object is only reported to the canvas here; the engines are
    # constructed without it.
    crossEnt = torch.nn.BCELoss()
    _signal("\n[INFO]: Loss: \n" + str(crossEnt))

    # The dataset name is the third '/'-separated component of the view's
    # data directory (index 2 of viewArgs["dir"].split('/')).
    dataset_name = viewArgs["dir"].split('/')[2]
    prediction_data_path = 'model/Prediction/data/{}'.format(dataset_name)
    trSet = ngsimDataset(os.path.join(prediction_data_path, "TrainSet.npy"))
    valSet = ngsimDataset(os.path.join(prediction_data_path, "ValSet.npy"))
    trDataloader = DataLoader(trSet, batch_size=args['batch_size'], shuffle=True,
                              num_workers=8, collate_fn=trSet.collate_fn)
    valDataloader = DataLoader(valSet, batch_size=args['batch_size'], shuffle=True,
                               num_workers=8, collate_fn=valSet.collate_fn)

    if predAlgo == "Traphic":
        args["ours"] = True
        engine = TraphicEngine(net, optim, trDataloader, valDataloader, args, thread)
    else:
        # Only "Social Conv" can reach this point.
        engine = SocialEngine(net, optim, trDataloader, valDataloader, args, thread)
    _signal("\n[INFO]: *** Training Prediction Model ***")
    engine.start()