Ejemplo n.º 1
0
def main(hidden_layers, hidden_size, batch_size, learning_rate):
    """Train a ``MultiLayerNet`` for 1000 epochs and save its weights.

    The model is optimised with Adam against ``BCEWithLogitsLoss`` (summed
    reduction). Every 20 epochs the mean training loss and the mean
    validation loss are printed and logged.

    Args:
        hidden_layers: Passed to ``MultiLayerNet`` as ``layers``; also used in
            the checkpoint file name.
        hidden_size: Passed to ``MultiLayerNet`` as ``hidden_size``; also used
            in the checkpoint file name.
        batch_size: Batch size of the training and validation loaders.
        learning_rate: Adam learning rate for epochs 0-299. From epoch 300 the
            rate is 1e-4 and from epoch 600 it is 5e-5, whatever this value.

    Returns:
        A dict holding the four hyper-parameters plus ``'loss'``, the list of
        per-epoch mean training losses.
    """
    import os  # local import: only needed to create the checkpoint directory

    log.debug("hidden_layers=%d, hidden_size=%d, batch_size=%d, learning_rate=%f" %(hidden_layers, hidden_size, batch_size, learning_rate))

    # Resolve the output path first so a missing directory cannot make
    # torch.save fail after the whole training run has already finished.
    save_model = "weights/model_hl%d_hs%d_bs%d_bn.pth" %(hidden_layers, hidden_size, batch_size)
    os.makedirs(os.path.dirname(save_model), exist_ok=True)

    model = MultiLayerNet(input_size=9, hidden_size=hidden_size, output_size=1, layers=hidden_layers)
    criterion = nn.BCEWithLogitsLoss(reduction='sum')

    weight_decay = 0.0005
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    epoch = 1000
    # Epoch index -> learning rate of the fresh optimizer created at that epoch.
    # NOTE(review): building a new Adam instance also discards its running
    # moment estimates; confirm this restart is intended rather than a plain
    # learning-rate change.
    lr_restarts = {300: 1e-4, 600: 5e-5}

    train_dataset = Dataset('dataset/train2.json')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=train_dataset.collate_fn, num_workers=4, pin_memory=True)
    val_dataset = Dataset('dataset/test2.json')
    # Validation only accumulates a loss, so there is no reason to shuffle.
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=val_dataset.collate_fn, num_workers=4, pin_memory=True)
    n_train = len(train_dataset)
    n_val = len(val_dataset)

    model = model.to(device)
    loss_data = []
    for t in range(epoch):
        if t in lr_restarts:
            optimizer = torch.optim.Adam(params=model.parameters(), lr=lr_restarts[t], weight_decay=weight_decay)

        # The criterion sums over samples, so dividing by the dataset size
        # gives a per-sample mean (assumes train() returns the epoch total).
        sum_loss = train(train_loader, model, criterion, optimizer)
        mean_loss = sum_loss / n_train
        loss_data.append(mean_loss)

        if t % 20 == 0:
            print("Epoch[%d] loss:%f" %(t, mean_loss))
            log.debug("Epoch[%d] loss:%f" %(t, mean_loss))
            val_sum_loss = validate(val_loader, model, criterion)
            val_mean_loss = val_sum_loss / n_val
            print("Validate loss:%f" %(val_mean_loss))
            log.debug("Validate loss:%f" %(val_mean_loss))

    log.debug("training finish.")
    log.debug(" \n \n \n")
    torch.save(model.state_dict(), save_model)
    data = {
        'hidden_layers':hidden_layers,
        'hidden_size':hidden_size,
        'batch_size':batch_size,
        'learning_rate':learning_rate,
        'loss':loss_data
    }
    return data
Ejemplo n.º 2
0
def get_pdfs(dataset: data.Dataset):
    """Print every ``(x, y)`` sample of ``dataset``, in index order.

    Three empty ``Counter`` objects are created up front; nothing in the
    visible body fills them in, and the function returns ``None``.

    Args:
        dataset: Indexable collection whose items unpack into two values.
    """
    pdf_x = Counter()
    pdf_t = Counter()
    pdf_xt = Counter()

    total = len(dataset)
    for idx in range(total):
        x, y = dataset[idx]
        print(x, y)
Ejemplo n.º 3
0
def make_weights_for_balanced_classes(dataset: Dataset):
    """Return one sampling weight per sample so that classes are balanced.

    Each sample gets ``N / count(label)``, where ``N`` is ``len(dataset)``
    and ``count(label)`` is how often that sample's label occurs in
    ``dataset.labels``. Rare classes therefore get proportionally larger
    weights, so the total weight of every class is the same.

    Args:
        dataset: Object exposing ``__len__`` and a ``labels`` sequence with
            one label per sample.

    Returns:
        A list of floats aligned with ``dataset.labels``; empty when there
        are no labels.
    """
    labels = np.asarray(dataset.labels).ravel()
    N = len(dataset)

    # A single pass yields every distinct label with its frequency, and keys
    # the table by the real label value rather than assuming 0..k-1.
    classes, counts = np.unique(labels, return_counts=True)
    weight_per_class = {
        cls: N / cnt for cls, cnt in zip(classes.tolist(), counts.tolist())
    }

    return [weight_per_class[label] for label in labels.tolist()]
    def __init__(self, batch_size, dataset_name, method, num_workers):
        """Set up the model, data loader and pretrained weights for testing.

        Args:
            batch_size: Batch size of the data loader.
            dataset_name: Forwarded to ``Dataset`` and stored on the instance.
            method: Forwarded to ``Dataset``.
            num_workers: Number of worker processes of the data loader.
        """
        self._build_model()
        self.dataset_name = dataset_name
        self.pretrained_model = '/data/pose_estimation'
        self.mse_all_img = 0
        self.per_image = 0

        dataset = Dataset(method=method, dataset_name=dataset_name)
        self.save_img_root = dataset.image_save_root
        self.datalen = len(dataset)
        self.root = dataset.root
        self.error_check_x = 0
        self.error_check_y = 0
        self.error_check_z = 0

        # Samples are visited in dataset order (no shuffling); num_workers is
        # honoured here instead of being silently ignored.
        self.dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                                     num_workers=num_workers, shuffle=False)

        # Load of pretrained_weight file #100epoch32bat/
        weight_PATH = self.pretrained_model + '/result_model_depth/obman/estimation/' + '2021_1_6_70_model.pth' #92
        # strict=False: keys that do not match the model are skipped, not fatal.
        # Presumably self.poseNet is created by _build_model() - confirm.
        self.poseNet.load_state_dict(torch.load(weight_PATH), strict=False)

        print("Testing...")
    def __init__(self, epochs, batch_size, lr, dataset_name, method, num_workers):
        """Set up the model, loss, optimizer, data loader and weights for training.

        Args:
            epochs: Stored as ``self.epochs``.
            batch_size: Batch size of the data loader.
            lr: Learning rate of the Adam optimizer.
            dataset_name: Forwarded to ``Dataset`` and stored on the instance.
            method: Forwarded to ``Dataset`` and stored on the instance.
            num_workers: Number of worker processes of the data loader.
        """
        self.epochs = epochs
        self.batch_size = batch_size
        self.learning_rate = lr

        self.dataset_name = dataset_name
        self.method = method
        self._build_model()

        self.cost = nn.MSELoss()
        # Presumably self.poseNet is created by _build_model() - confirm.
        self.optimizer = optim.Adam(self.poseNet.parameters(), lr=self.learning_rate)

        dataset = Dataset(method=method, dataset_name=dataset_name)
        self.datalen = len(dataset)
        self.root = dataset.root
        self.dataloader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)

        # Load of pretrained_weight file
        weight_PATH = './result_model_depth/2021_1_4_0_model.pth'
        # strict=False: keys that do not match the model are skipped, not fatal.
        self.poseNet.load_state_dict(torch.load(weight_PATH), strict=False)

        # Read the clock once so year, month and day always belong to the same
        # instant, even when construction straddles midnight.
        today = datetime.today()
        self.date = "_".join([str(today.year), str(today.month), str(today.day)])

        print("Training...")
Ejemplo n.º 6
0
@torch.no_grad()
def test(args):
    """Evaluate the LSTM location model next to a Kalman-filter tracker.

    The ``*_bdd_roipool_output.pkl`` files under ``args.path`` are split into
    ``args.n_parts`` groups of equal size and evaluated one group at a time.
    For every batch the model and a ``KalmanBox3dTracker`` are rolled out
    step by step; masked L1 losses and linear-motion losses are accumulated
    in ``AverageMeter`` objects and printed.

    The function runs under ``torch.no_grad()`` because nothing is optimised
    here, so no autograd graph has to be kept.

    Args:
        args: Namespace of options. The fields read here are ``batch_size``,
            ``feature_dim``, ``hidden_size``, ``num_layers``, ``loc_dim``,
            ``device``, ``resume``, ``ckpt_path``, ``session``, ``set``,
            ``start_epoch``, ``path``, ``n_parts``, ``workers``, ``num_seq``,
            ``depth_weight``, ``show_freq`` and ``is_plot``.
    """
    model = LSTM(args.batch_size,
                 args.feature_dim,
                 args.hidden_size,
                 args.num_layers,
                 args.loc_dim).to(args.device)
    model = model.eval()

    if args.resume:
        nu.load_checkpoint(
            model,
            args.ckpt_path.format(args.session, args.set, args.start_epoch),
            is_test=True)

    if args.path.endswith('_bdd_roipool_output.pkl'):
        # A single file was given: one part holding exactly that file.
        paths = [args.path]
        dsize = len(paths)
        n_parts = 1
    else:
        paths = sorted(
            [os.path.join(args.path, n) for n in os.listdir(args.path)
             if n.endswith('_bdd_roipool_output.pkl')])
        # Integer division: files beyond dsize * n_parts are never loaded.
        dsize = len(paths) // args.n_parts
        n_parts = args.n_parts

    print("Total {} sequences separate into {} parts".format(dsize, n_parts))

    dataset = Dataset(args, paths[:dsize])

    print("Number of trajectories to test: {}".format(len(dataset)))

    # Data loading code
    test_loader = DataLoader(
        dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    # Running averages for the LSTM model ...
    losses = tu.AverageMeter()
    pred_losses = tu.AverageMeter()
    refine_losses = tu.AverageMeter()
    lin_losses = tu.AverageMeter()
    # ... and for the Kalman-filter tracker.
    losses_kf = tu.AverageMeter()
    pred_losses_kf = tu.AverageMeter()
    refine_losses_kf = tu.AverageMeter()
    lin_losses_kf = tu.AverageMeter()

    epoch = args.start_epoch

    summary_fmt = ('[{NAME} - {SESS}] Epoch: [{EP}][{IT}/{TO}] '
                   'Loss {loss.val:.4f} ({loss.avg:.3f}) '
                   'P-Loss {pred.val:.2f} ({pred.avg:.2f}) '
                   'R-Loss {refine.val:.2f} ({refine.avg:.2f}) '
                   'S-Loss {smooth.val:.2f} ({smooth.avg:.2f}) \n'
                   'Loss {loss_kf.val:.4f} ({loss_kf.avg:.3f}) '
                   'P-Loss {pred_kf.val:.2f} ({pred_kf.avg:.2f}) '
                   'R-Loss {refine_kf.val:.2f} ({refine_kf.avg:.2f}) '
                   'S-Loss {smooth_kf.val:.2f} ({smooth_kf.avg:.2f}) ')
    sample_fmt = ("PD: {pd} OB: {obs} RF: {ref} GT: {gt} \n"
                  "PDKF: {pdkf} OBKF: {obs} RFKF: {refkf} GT: {gt}")

    for part in range(n_parts):
        for iters, traj_out in enumerate(test_loader):

            # Initial
            loc_gt = traj_out['loc_gt'].to(args.device)[:, 1:]
            loc_obs = traj_out['loc_pd'].to(args.device)
            cam_loc = traj_out['cam_loc'].to(args.device)
            valid_mask = traj_out['valid_mask'].to(args.device)

            loc_preds = []
            loc_refines = []
            loc_preds_kf = []
            loc_refines_kf = []

            # Start every batch from a fresh hidden state so nothing leaks
            # over from the previous batch.
            hidden_predict = model.init_hidden(args.device)  # None
            hidden_refine = model.init_hidden(args.device)  # None

            # Sliding window holding the last args.num_seq velocity entries
            vel_history = loc_obs.new_zeros(args.num_seq, loc_obs.shape[0], 3)
            for i in range(valid_mask.shape[1]):

                # The first step is seeded from the observation
                if i == 0:
                    loc_refine = loc_obs[:, i]
                    trk = KalmanBox3dTracker(loc_refine.cpu().detach().numpy())
                # Drop the oldest entry and append the newest one.
                # NOTE(review): train() updates this history after predict()
                # using loc_obs[:, i + 1] - loc_pred; confirm which of the
                # two orderings is intended.
                vel_history = torch.cat([vel_history[1:], (loc_obs[:, i] - loc_refine).unsqueeze(0)], dim=0)
                loc_pred, hidden_predict = model.predict(vel_history,
                                                         loc_refine,
                                                         hidden_predict)
                loc_refine, hidden_refine = model.refine(loc_pred,
                                                         loc_obs[:, i + 1],
                                                         hidden_refine)

                loc_pred_kf = trk.predict().squeeze()
                trk.update(loc_obs[:, i+1].cpu().detach().numpy())
                loc_refine_kf = trk.get_state()[:3]

                loc_preds.append(loc_pred)
                loc_refines.append(loc_refine)
                loc_preds_kf.append(loc_pred_kf)
                loc_refines_kf.append(loc_refine_kf)

            # Reshape the per-step outputs to (batch, -1, 3)
            batch = valid_mask.shape[0]
            loc_preds = torch.cat(loc_preds, dim=1).view(batch, -1, 3)
            loc_refines = torch.cat(loc_refines, dim=1).view(batch, -1, 3)
            loc_preds_kf = valid_mask.new(loc_preds_kf).view(batch, -1, 3)
            loc_refines_kf = valid_mask.new(loc_refines_kf).view(batch, -1, 3)

            # Zero out the masked steps on both sides of every comparison
            mask = valid_mask.unsqueeze(2)
            loc_preds = loc_preds * mask
            loc_refines = loc_refines * mask
            loc_preds_kf = loc_preds_kf * mask
            loc_refines_kf = loc_refines_kf * mask
            gt_masked = loc_gt * mask
            valid_sum = torch.sum(valid_mask)

            # Cost functions: summed L1 divided by the sum of the mask
            pred_loss = F.l1_loss(loc_preds, gt_masked,
                                  reduction='sum') / valid_sum
            refine_loss = F.l1_loss(loc_refines, gt_masked,
                                    reduction='sum') / valid_sum
            pred_loss_kf = F.l1_loss(loc_preds_kf, gt_masked,
                                     reduction='sum') / valid_sum
            refine_loss_kf = F.l1_loss(loc_refines_kf, gt_masked,
                                       reduction='sum') / valid_sum
            dep_loss = pred_loss + refine_loss
            linear_loss = nu.linear_motion_loss(loc_preds, valid_mask)
            linear_loss += nu.linear_motion_loss(loc_refines, valid_mask)
            dep_loss_kf = pred_loss_kf + refine_loss_kf
            linear_loss_kf = nu.linear_motion_loss(loc_preds_kf, valid_mask)
            linear_loss_kf += nu.linear_motion_loss(loc_refines_kf, valid_mask)

            loss = (args.depth_weight * dep_loss +
                    (1.0 - args.depth_weight) * linear_loss)
            loss_kf = (args.depth_weight * dep_loss_kf +
                       (1.0 - args.depth_weight) * linear_loss_kf)

            # Updates, each weighted by the mask sum of this batch
            n_valid = int(valid_sum)
            losses.update(loss.item(), n_valid)
            pred_losses.update(pred_loss.item(), n_valid)
            refine_losses.update(refine_loss.item(), n_valid)
            lin_losses.update(linear_loss.item(), n_valid)
            losses_kf.update(loss_kf.item(), n_valid)
            pred_losses_kf.update(pred_loss_kf.item(), n_valid)
            refine_losses_kf.update(refine_loss_kf.item(), n_valid)
            lin_losses_kf.update(linear_loss_kf.item(), n_valid)

            # Verbose
            summary = summary_fmt.format(
                NAME=args.set.upper(),
                SESS=args.session, EP=epoch,
                IT=iters, TO=len(test_loader),
                loss=losses, pred=pred_losses,
                refine=refine_losses, smooth=lin_losses,
                loss_kf=losses_kf, pred_kf=pred_losses_kf,
                refine_kf=refine_losses_kf, smooth_kf=lin_losses_kf)
            sample = sample_fmt.format(
                pd=loc_preds.cpu().data.numpy().astype(int)[0, 0],
                obs=loc_obs.cpu().data.numpy().astype(int)[0, 1],
                ref=loc_refines.cpu().data.numpy().astype(int)[0, 0],
                pdkf=loc_preds_kf.cpu().data.numpy().astype(int)[0, 0],
                refkf=loc_refines_kf.cpu().data.numpy().astype(int)[0, 0],
                gt=loc_gt.cpu().data.numpy().astype(int)[0, 0])

            if iters % args.show_freq == 0 and iters != 0:
                print(summary)
                print(sample)

                if args.is_plot:
                    plot_3D('{}_{}'.format(epoch, iters), args.session,
                            cam_loc.cpu().data.numpy()[0],
                            loc_gt.cpu().data.numpy()[0],
                            loc_preds.cpu().data.numpy()[0],
                            loc_refines.cpu().data.numpy()[0])
                    plot_3D('{}_{}_kf'.format(epoch, iters), args.session,
                            cam_loc.cpu().data.numpy()[0],
                            loc_gt.cpu().data.numpy()[0],
                            loc_preds_kf.cpu().data.numpy()[0],
                            loc_refines_kf.cpu().data.numpy()[0])

            # NOTE(review): this report runs on every iteration, so on
            # show_freq iterations it is printed twice; confirm whether it
            # was meant to run once per part instead.
            print(summary)
            print(sample)

        # Load the NEXT part. Reloading the part that just finished would
        # evaluate part 0 twice and never reach the last part.
        if part + 1 < n_parts:
            next_start = (part + 1) * dsize
            dataset = Dataset(args, paths[next_start:next_start + dsize])

            print(
                "Number of trajectories to test: {}".format(len(dataset)))

            # Data loading code; unshuffled like the loader of the first part
            test_loader = DataLoader(
                dataset,
                batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True, drop_last=True)
Ejemplo n.º 7
0
def train(args):
    """Train the LSTM location model on trajectory files, part by part.

    The ``*_bdd_roipool_output.pkl`` files under ``args.path`` are split into
    ``args.n_parts`` groups of equal size. Each epoch visits the groups one
    after another; only one group is held in a ``DataLoader`` at a time.
    After every group the averaged losses are printed and, except during the
    starting epoch, a checkpoint is written.

    Args:
        args: Namespace of options. The fields read here are ``batch_size``,
            ``feature_dim``, ``hidden_size``, ``num_layers``, ``loc_dim``,
            ``device``, ``lr``, ``weight_decay``, ``resume``, ``ckpt_path``,
            ``session``, ``set``, ``start_epoch``, ``num_epochs``, ``path``,
            ``n_parts``, ``workers``, ``num_seq``, ``depth_weight``,
            ``show_freq`` and ``is_plot``.
    """
    model = LSTM(args.batch_size,
                 args.feature_dim,
                 args.hidden_size,
                 args.num_layers,
                 args.loc_dim).to(args.device)
    model = model.train()

    # Only parameters that require gradients are handed to the optimizer
    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        weight_decay=args.weight_decay,
        amsgrad=True
    )

    if args.resume:
        nu.load_checkpoint(
            model,
            args.ckpt_path.format(args.session, args.set, args.start_epoch),
            optimizer=optimizer)

    if args.path.endswith('_bdd_roipool_output.pkl'):
        # A single file was given: one part holding exactly that file.
        paths = [args.path]
        dsize = len(paths)
        n_parts = 1
    else:
        paths = sorted(
            [os.path.join(args.path, n) for n in os.listdir(args.path)
             if n.endswith('_bdd_roipool_output.pkl')])
        # Integer division: files beyond dsize * n_parts are never loaded.
        dsize = len(paths) // args.n_parts
        n_parts = args.n_parts

    print("Total {} sequences separate into {} parts".format(dsize, n_parts))

    dataset = Dataset(args, paths[:dsize])

    print("Number of trajectories to train: {}".format(len(dataset)))

    # Data loading code
    train_loader = DataLoader(
        dataset,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    # Start iterations
    for epoch in range(args.start_epoch, args.num_epochs + 1):
        losses = tu.AverageMeter()
        pred_losses = tu.AverageMeter()
        refine_losses = tu.AverageMeter()
        lin_losses = tu.AverageMeter()
        nu.adjust_learning_rate(args, optimizer, epoch)

        for part in range(n_parts):

            for iters, traj_out in enumerate(train_loader):

                # Initial
                loc_gt_full = traj_out['loc_gt'].to(args.device)
                loc_gt = loc_gt_full[:, 1:]
                # The observation fed to the model is a 5% / 95% blend of
                # 'loc_pd' and 'loc_gt'.
                loc_obs = 0.05 * traj_out['loc_pd'].to(args.device) + \
                    0.95 * loc_gt_full
                cam_loc = traj_out['cam_loc'].to(args.device)
                valid_mask = traj_out['valid_mask'].to(args.device)

                loc_preds = []
                loc_refines = []

                # Start every batch from a fresh hidden state so gradients
                # never flow back into a previous batch.
                hidden_predict = model.init_hidden(args.device)  # None
                hidden_refine = model.init_hidden(args.device)  # None

                # Sliding window holding the last args.num_seq velocity entries
                vel_history = loc_obs.new_zeros(args.num_seq, loc_obs.shape[0], 3)
                for i in range(valid_mask.shape[1]):

                    # The first step is seeded from the observation
                    if i == 0:
                        loc_refine = loc_obs[:, i]
                    loc_pred, hidden_predict = model.predict(vel_history,
                                                             loc_refine,
                                                             hidden_predict)
                    # Drop the oldest entry and append the newest one
                    vel_history = torch.cat([vel_history[1:], (loc_obs[:, i+1] - loc_pred).unsqueeze(0)], dim=0)
                    loc_refine, hidden_refine = model.refine(loc_pred,
                                                             loc_obs[:, i + 1],
                                                             hidden_refine)
                    # NOTE(review): this prints on every time step of every
                    # batch and moves four tensors to the CPU each time; it
                    # looks like leftover debugging - confirm before removing.
                    print(vel_history[:, 0, :].cpu().detach().numpy(), 
                        loc_pred[0, :].cpu().detach().numpy(), 
                        loc_refine[0, :].cpu().detach().numpy(), 
                        loc_gt[0, i, :].cpu().detach().numpy())
                    loc_preds.append(loc_pred)
                    loc_refines.append(loc_refine)

                # Reshape the per-step outputs to (batch, -1, 3)
                batch = valid_mask.shape[0]
                loc_preds = torch.cat(loc_preds, dim=1).view(batch, -1, 3)
                loc_refines = torch.cat(loc_refines, dim=1).view(batch, -1, 3)

                # Zero out the masked steps on both sides of every comparison
                mask = valid_mask.unsqueeze(2)
                loc_preds = loc_preds * mask
                loc_refines = loc_refines * mask
                gt_masked = loc_gt * mask
                valid_sum = torch.sum(valid_mask)

                # Cost functions: summed L1 divided by the sum of the mask
                pred_loss = F.l1_loss(loc_preds, gt_masked,
                                      reduction='sum') / valid_sum
                refine_loss = F.l1_loss(loc_refines, gt_masked,
                                        reduction='sum') / valid_sum
                dep_loss = (pred_loss + refine_loss)
                linear_loss = nu.linear_motion_loss(loc_preds, valid_mask)
                linear_loss += nu.linear_motion_loss(loc_refines, valid_mask)

                loss = (args.depth_weight * dep_loss +
                        (1.0 - args.depth_weight) * linear_loss)

                # Clear the gradients left over from the previous step
                optimizer.zero_grad()

                # BP loss
                loss.backward()

                # Clip if the gradients explode
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

                optimizer.step()

                # Updates, each weighted by the mask sum of this batch
                n_valid = int(valid_sum)
                losses.update(loss.item(), n_valid)
                pred_losses.update(pred_loss.item(), n_valid)
                refine_losses.update(refine_loss.item(), n_valid)
                lin_losses.update(linear_loss.item(), n_valid)

                # Verbose
                if iters % args.show_freq == 0 and iters != 0:
                    print('[{NAME} - {SESS}] Epoch: [{EP}][{IT}/{TO}] '
                          'Loss {loss.val:.4f} ({loss.avg:.3f}) '
                          'P-Loss {pred.val:.2f} ({pred.avg:.2f}) '
                          'R-Loss {refine.val:.2f} ({refine.avg:.2f}) '
                          'S-Loss {smooth.val:.2f} ({smooth.avg:.2f}) '.format(
                        NAME=args.set.upper(),
                        SESS=args.session, EP=epoch,
                        IT=iters, TO=len(train_loader),
                        loss=losses, pred=pred_losses,
                        refine=refine_losses, smooth=lin_losses))
                    print("PD: {pd} OB: {obs} RF: {ref} GT: {gt}".format(
                        pd=loc_preds.cpu().data.numpy().astype(int)[0, 2],
                        obs=loc_obs.cpu().data.numpy().astype(int)[0, 3],
                        ref=loc_refines.cpu().data.numpy().astype(int)[0, 2],
                        gt=loc_gt.cpu().data.numpy().astype(int)[0, 2]))

                    if args.is_plot:
                        plot_3D('{}_{}'.format(epoch, iters), args.session,
                                cam_loc.cpu().data.numpy()[0],
                                loc_gt.cpu().data.numpy()[0],
                                loc_preds.cpu().data.numpy()[0],
                                loc_refines.cpu().data.numpy()[0])

            # Save
            if epoch != args.start_epoch:
                torch.save({'epoch': epoch,
                            'state_dict': model.state_dict(),
                            'session': args.session},
                           args.ckpt_path.format(args.session, args.set, epoch))

            print(
                "Epoch [{}] Loss: {:.3f} P-Loss: {:.3f} R-Loss: {:.3f} "
                "S-Loss: {:.3f} ".format(
                    epoch,
                    losses.avg,
                    pred_losses.avg,
                    refine_losses.avg,
                    lin_losses.avg))

            if n_parts != 1:
                # Load the NEXT part, wrapping to part 0 for the following
                # epoch. Reloading the part that just finished would train
                # part 0 twice in the first epoch and leave every later part
                # one step behind.
                next_start = ((part + 1) % n_parts) * dsize
                dataset = Dataset(args,
                                  paths[next_start:next_start + dsize])

                print("Number of trajectories to train: {}".format(
                    len(dataset)))

                # Data loading code
                train_loader = DataLoader(
                    dataset,
                    batch_size=args.batch_size, shuffle=True,
                    num_workers=args.workers, pin_memory=True, drop_last=True)
Ejemplo n.º 8
0
    def __getitem__(self, index):
        """Return the resized image and the steering label of one sample.

        Args:
            index: Position into ``self.image_ID`` and ``self.Steer_Ang``.

        Returns:
            A tuple ``(image, steer_label)``: the image as a float tensor
            built from the resized OpenCV array, and the label exactly as it
            is stored in ``self.Steer_Ang``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        image_path = self.filepath + '/' + self.image_ID[index]
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread reports a missing or unreadable file by returning
            # None; fail here with the path instead of inside cv2.resize.
            raise FileNotFoundError("Could not read image: " + image_path)

        # NOTE(review): cv2.resize takes dsize as (width, height), so this
        # yields an image 66 wide and 200 high - confirm that is what the
        # network expects.
        img = cv2.resize(img, (66, 200), interpolation=cv2.INTER_AREA)

        # NOTE(review): the label is returned as stored. A degrees-to-radians
        # conversion used to be computed here but was never returned - confirm
        # which unit the model is supposed to be trained on.
        steer_labels = self.Steer_Ang[index]

        return (torch.from_numpy(img).float(), steer_labels)


# Build the full dataset, then carve it into train / validation / test subsets
data = Dataset()
total_len = len(data)
train_split = int(0.44708 * total_len)
valid_split = int(0.3327 * total_len)
# Whatever the two fractional splits leave over goes to the test subset
test_split = total_len - (train_split + valid_split)
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
    data, [train_split, valid_split, test_split])

# Report the four sizes, one per line
print(total_len, train_split, valid_split, test_split, sep="\n")

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=200, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset,
Ejemplo n.º 9
0
 def _generate_bootstrap_dataset(self, dataset: Dataset) -> Subset:
     dataset_size = dataset.__len__()
     sampled_indices = choices(range(dataset_size), k=dataset_size)
     return Subset(dataset, sampled_indices)