def train_a_batch(self, engine, batch):
        """Run one optimisation step on ``batch`` and return its scalar loss.

        The criterion depends on the current epoch: while
        ``engine.state.epoch < self.pretrainEpochs`` it is chosen by
        ``self.args["pretrain_loss"]`` (``'NLL'`` selects NLL, anything else
        MSE); afterwards by ``self.args["train_loss"]`` (``'MSE'`` selects
        MSE, anything else NLL).

        Args:
            engine: Object exposing ``state.epoch`` and ``state.iteration``.
            batch: Iterable of exactly nine items. The last two are bound to
                ``fut`` and ``op_mask`` (presumably the ground-truth future
                and its output mask -- confirm against the data loader). The
                whole batch is also handed to ``self.netPred``.

        Returns:
            The value of ``loss.item()`` for this batch.
        """
        self.net.train_flag = True
        current_epoch = engine.state.epoch

        _, _, _, _, _, _, _, fut, op_mask = batch
        fut_pred = self.netPred(batch)

        if self.cuda:
            fut, op_mask = fut.cuda(self.device), op_mask.cuda(self.device)

        if current_epoch < self.pretrainEpochs:
            # Pre-training: unknown settings fall back to MSE.
            wants_nll = self.args["pretrain_loss"] == 'NLL'
            criterion = maskedNLL if wants_nll else maskedMSE
        else:
            # Main training: unknown settings fall back to NLL.
            wants_mse = self.args["train_loss"] == 'MSE'
            criterion = maskedMSE if wants_mse else maskedNLL
        loss = criterion(fut_pred, fut, op_mask, device=self.device)

        # Backprop and update weights.
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

        # Read the scalar once and reuse it for every bookkeeping update.
        batch_loss = loss.item()
        self.avg_trn_loss += batch_loss
        self.metrics["Avg train loss"] += batch_loss / 100.0

        if self.writer:
            tag = "{}epoch/trainingloss".format(engine.state.epoch)
            self.writer.add_scalar(tag, batch_loss, engine.state.iteration)

        return batch_loss
コード例 #2
0
ファイル: trajPredEngine.py プロジェクト: anik-jha/tnp
    def train_a_batch(self, engine, batch):
        """Run one optimisation step on ``batch`` and return its scalar loss.

        The loss is selected from ``self.args``: key ``"pretrain_loss"`` while
        ``engine.state.epoch < self.pretrainEpochs``, key ``"train_loss"``
        afterwards. Recognised values are ``'MSE'`` and ``'NLL'``. Any other
        value falls back to MSE (pre-training) or NLL (training) and, when
        ``self.thread`` is set, is reported via ``self.thread.signalError``.

        Args:
            engine: Object exposing ``state.epoch``.
            batch: Iterable of exactly nine items. The last two are bound to
                ``fut`` and ``op_mask`` (presumably the ground-truth future
                and its output mask -- confirm against the data loader). The
                whole batch is also handed to ``self.netPred``.

        Returns:
            The value of ``l.item()`` for this batch.
        """
        self.net.train_flag = True
        epoch = engine.state.epoch

        _, _, _, _, _, _, _, fut, op_mask = batch

        fut_pred = self.netPred(batch)

        if self.cuda:
            fut = fut.cuda()
            op_mask = op_mask.cuda()

        # Pick the configuration key, fallback loss and warning for this phase.
        if epoch < self.pretrainEpochs:
            requested = self.args["pretrain_loss"]
            fallback = maskedMSE
            warning = "[Error] Unrecognized pretrain loss, using MSE by default"
        else:
            requested = self.args["train_loss"]
            fallback = maskedNLL
            warning = "[Error] Unrecognized train loss, using NLL by default"

        if requested == 'MSE':
            loss_fn = maskedMSE
        elif requested == 'NLL':
            loss_fn = maskedNLL
        else:
            # Guard both phases: previously only the pre-training branch
            # checked self.thread, so an unknown train loss with no thread
            # attached raised AttributeError instead of falling back.
            if self.thread:
                self.thread.signalError(warning)
            loss_fn = fallback
        l = loss_fn(fut_pred, fut, op_mask)

        # Backprop and update weights
        self.optim.zero_grad()
        l.backward()
        self.optim.step()

        # Track average train loss; read the scalar once and reuse it.
        loss_value = l.item()
        self.avg_trn_loss += loss_value
        self.metrics["Avg train loss"] += loss_value / 100.0

        return loss_value
コード例 #3
0
ファイル: trajPredEngine.py プロジェクト: anik-jha/tnp
    def eval_a_batch(self, engine, batch):
        """Evaluate ``batch``, update validation statistics, return outputs.

        The loss is selected from ``self.args``: key ``"pretrain_loss"`` while
        ``engine.state.epoch < self.pretrainEpochs``, key ``"train_loss"``
        afterwards. Recognised values are ``'MSE'`` and ``'NLL'``. Any other
        value falls back to MSE (pre-training) or NLL (training) and, when
        ``self.thread`` is set, is reported via ``self.thread.signalError``.

        Args:
            engine: Object exposing ``state.epoch``.
            batch: Iterable of exactly nine items. The last two are bound to
                ``fut`` and ``op_mask`` (presumably the ground-truth future
                and its output mask -- confirm against the data loader). The
                whole batch is also handed to ``self.netPred``.

        Returns:
            Tuple ``(fut_pred, fut)``: the output of ``self.netPred`` and the
            target (moved with ``.cuda()`` when ``self.cuda`` is truthy).
        """
        self.net.train_flag = False

        epoch = engine.state.epoch

        _, _, _, _, _, _, _, fut, op_mask = batch
        # NOTE(review): no torch.no_grad() is applied in this method; unless
        # the caller wraps it, autograd state is still tracked during eval.
        fut_pred = self.netPred(batch)
        if self.cuda:
            fut = fut.cuda()
            op_mask = op_mask.cuda()

        # Pick the configuration key, fallback loss and warning for this phase.
        if epoch < self.pretrainEpochs:
            requested = self.args["pretrain_loss"]
            fallback = maskedMSE
            warning = "[Error] Unrecognized pretrain loss, using MSE by default"
        else:
            requested = self.args["train_loss"]
            fallback = maskedNLL
            warning = "[Error] Unrecognized train loss, using NLL by default"

        if requested == 'MSE':
            loss_fn = maskedMSE
        elif requested == 'NLL':
            loss_fn = maskedNLL
        else:
            # Guard both phases: previously only the pre-training branch
            # checked self.thread, so an unknown train loss with no thread
            # attached raised AttributeError instead of falling back.
            if self.thread:
                self.thread.signalError(warning)
            loss_fn = fallback
        l = loss_fn(fut_pred, fut, op_mask)

        # Read the scalar once and reuse it for both statistics.
        loss_value = l.item()
        self.avg_val_loss += loss_value
        # NOTE(review): the divisor uses val_batch_count *before* it is
        # incremented, so the counter must start non-zero -- confirm where
        # it is initialised.
        self.metrics["Avg val loss"] += loss_value / (self.val_batch_count *
                                                      100.0)
        self.val_batch_count += 1

        return fut_pred, fut
コード例 #4
0
    def eval_a_batch(self, engine, batch):
        """Evaluate ``batch``, update validation statistics, return outputs.

        Criterion selection, in order:

        1. ``self.args['nll_only']`` truthy -> NLL.
        2. ``engine.state.epoch < self.pretrainEpochs`` -> NLL when
           ``self.args["pretrain_loss"] == 'NLL'``, otherwise MSE.
        3. Otherwise -> MSE when ``self.args["train_loss"] == 'MSE'``,
           otherwise NLL.

        Args:
            engine: Object exposing ``state.epoch``.
            batch: Iterable of exactly thirteen items. Items at positions 7
                and 8 (0-based) are bound to ``fut`` and ``op_mask``
                (presumably the ground-truth future and its output mask --
                confirm against the data loader). The whole batch is also
                handed to ``self.netPred``.

        Returns:
            Tuple ``(fut_pred, fut)``: the output of ``self.netPred`` and the
            target (moved to ``self.device`` when ``self.cuda`` is truthy).
        """
        self.net.train_flag = False
        current_epoch = engine.state.epoch

        _, _, _, _, _, _, _, fut, op_mask, _, _, _, _ = batch
        fut_pred = self.netPred(batch)

        if self.cuda:
            fut, op_mask = fut.cuda(self.device), op_mask.cuda(self.device)

        if self.args['nll_only']:
            criterion = maskedNLL
        elif current_epoch < self.pretrainEpochs:
            # Pre-training: unknown settings fall back to MSE.
            wants_nll = self.args["pretrain_loss"] == 'NLL'
            criterion = maskedNLL if wants_nll else maskedMSE
        else:
            # Main training: unknown settings fall back to NLL.
            wants_mse = self.args["train_loss"] == 'MSE'
            criterion = maskedMSE if wants_mse else maskedNLL
        loss = criterion(fut_pred, fut, op_mask, device=self.device)

        # Read the scalar once and reuse it for both statistics.
        batch_loss = loss.item()
        self.avg_val_loss += batch_loss
        # The divisor uses the batch counter before it is incremented.
        scale = self.val_batch_count * 100.0
        self.metrics["Avg val loss"] += batch_loss / scale
        self.val_batch_count += 1

        return fut_pred, fut
コード例 #5
0
    def train_a_batch(self, engine, batch):
        """Run one optimisation step on ``batch`` and return its scalar loss.

        Criterion selection, in order:

        1. ``self.args['nll_only']`` truthy -> NLL.
        2. ``engine.state.epoch < self.pretrainEpochs`` -> NLL when
           ``self.args["pretrain_loss"] == 'NLL'``, otherwise MSE.
        3. Otherwise -> MSE when ``self.args["train_loss"] == 'MSE'``,
           otherwise NLL.

        Args:
            engine: Object exposing ``state.epoch``.
            batch: Iterable of exactly thirteen items. Items at positions 7
                and 8 (0-based) are bound to ``fut`` and ``op_mask``. The
                whole batch is also handed to ``self.netPred``.
                NOTE(review): an earlier comment described each sample as
                "hist, fut, upper_neighbors, neighbors, lat_enc, lon_enc,
                bb, dsId, vehId, t", which does not match the 13-item
                unpack -- confirm the layout against the data loader.

        Returns:
            The value of ``loss.item()`` for this batch.
        """
        self.net.train_flag = True
        current_epoch = engine.state.epoch

        _, _, _, _, _, _, _, fut, op_mask, _, _, _, _ = batch
        fut_pred = self.netPred(batch)

        if self.cuda:
            fut, op_mask = fut.cuda(self.device), op_mask.cuda(self.device)

        if self.args['nll_only']:
            criterion = maskedNLL
        elif current_epoch < self.pretrainEpochs:
            # Pre-training: unknown settings fall back to MSE.
            wants_nll = self.args["pretrain_loss"] == 'NLL'
            criterion = maskedNLL if wants_nll else maskedMSE
        else:
            # Main training: unknown settings fall back to NLL.
            wants_mse = self.args["train_loss"] == 'MSE'
            criterion = maskedMSE if wants_mse else maskedNLL
        loss = criterion(fut_pred, fut, op_mask, device=self.device)

        # Backprop and update weights.
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

        # Read the scalar once and reuse it for every bookkeeping update.
        batch_loss = loss.item()
        self.avg_trn_loss += batch_loss
        self.metrics["Avg train loss"] += batch_loss / 100.0

        # Per-iteration logging through self.writer is disabled in this
        # variant of the method.
        return batch_loss