def init(batch_size, state, input_sizes, dataset, mean, std, base, workers=10):
    """Build data loader(s) for lane-detection training or evaluation.

    ``state`` selects the split:
      0: training            -> returns (train_loader, validation_loader)
      1: fast validation by mean IoU ('valfast' split)
      2: testing on the test split
      3: testing on the validation split
    States 1-3 return a single loader.

    Raises:
        ValueError: if ``state`` is not in {0, 1, 2, 3}.
    """
    # Transformations
    # ! Can't use torchvision.Transforms.Compose — presumably the project's
    # own Compose transforms image and label together; confirm in transforms
    # module.
    transforms_test = Compose(
        [Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
         ToTensor(),
         Normalize(mean=mean, std=std)])
    transforms_train = Compose(
        [Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
         RandomRotation(degrees=3),
         ToTensor(),
         Normalize(mean=mean, std=std)])

    if state == 0:
        data_set = StandardLaneDetectionDataset(root=base, image_set='train', transforms=transforms_train,
                                                data_set=dataset)
        data_loader = torch.utils.data.DataLoader(dataset=data_set, batch_size=batch_size,
                                                  num_workers=workers, shuffle=True)
        validation_set = StandardLaneDetectionDataset(root=base, image_set='val',
                                                      transforms=transforms_test, data_set=dataset)
        # Validation is forward-only, so a 4x batch size is used.
        validation_loader = torch.utils.data.DataLoader(dataset=validation_set, batch_size=batch_size * 4,
                                                        num_workers=workers, shuffle=False)
        return data_loader, validation_loader

    elif state in (1, 2, 3):
        # state maps 1->'valfast', 2->'test', 3->'val'
        image_sets = ['valfast', 'test', 'val']
        data_set = StandardLaneDetectionDataset(root=base, image_set=image_sets[state - 1],
                                                transforms=transforms_test, data_set=dataset)
        data_loader = torch.utils.data.DataLoader(dataset=data_set, batch_size=batch_size,
                                                  num_workers=workers, shuffle=False)
        return data_loader
    else:
        raise ValueError('state must be 0, 1, 2 or 3, got %r' % (state,))
# Example #2
def init(batch_size,
         state,
         input_sizes,
         dataset,
         mean,
         std,
         base,
         workers=10,
         method='baseline'):
    """Build data loader(s) for lane-detection training or evaluation.

    ``state`` selects the split:
      0: training            -> returns (train_loader, validation_loader)
      1: fast validation by mean IoU ('valfast' split)
      2: testing on the test split
      3: testing on the validation split
    States 1-3 return a single loader.

    ``method`` == 'lstr' switches to the point-based datasets
    (TuSimple/CULane) with a dict collate function and target
    normalization; any other value uses StandardLaneDetectionDataset.

    Raises:
        ValueError: for an unknown ``state``, or an unknown ``dataset``
            when ``method`` == 'lstr'.
    """
    # Transformations
    # ! Can't use torchvision.Transforms.Compose — presumably the project's
    # own Compose transforms image and label together; confirm in transforms
    # module.
    transforms_test = Compose([
        Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
        ToTensor(),
        Normalize(mean=mean, std=std)
    ])
    transforms_train = Compose([
        Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
        RandomRotation(degrees=3),
        ToTensor(),
        # LSTR targets are normalized as well; `x == y` already yields a
        # bool, no conditional expression needed.
        Normalize(mean=mean, std=std, normalize_target=(method == 'lstr'))
    ])

    # Batch builder: LSTR samples are dicts and need a custom collate.
    collate_fn = dict_collate_fn if method == 'lstr' else None

    # Point-based dataset classes used by LSTR, keyed by dataset name
    # (replaces the duplicated if/elif dispatch in both branches below).
    lstr_datasets = {'tusimple': TuSimple, 'culane': CULane}

    if state == 0:
        if method == 'lstr':
            if dataset not in lstr_datasets:
                raise ValueError('unsupported dataset for lstr: %r' % (dataset,))
            data_set = lstr_datasets[dataset](root=base,
                                              image_set='train',
                                              transforms=transforms_train,
                                              padding_mask=True,
                                              process_points=True)
        else:
            data_set = StandardLaneDetectionDataset(
                root=base,
                image_set='train',
                transforms=transforms_train,
                data_set=dataset)

        data_loader = torch.utils.data.DataLoader(dataset=data_set,
                                                  batch_size=batch_size,
                                                  collate_fn=collate_fn,
                                                  num_workers=workers,
                                                  shuffle=True)
        # Validation always uses the standard dataset; forward-only, so a
        # 4x batch size is used.
        validation_set = StandardLaneDetectionDataset(
            root=base,
            image_set='val',
            transforms=transforms_test,
            data_set=dataset)
        validation_loader = torch.utils.data.DataLoader(
            dataset=validation_set,
            batch_size=batch_size * 4,
            num_workers=workers,
            shuffle=False,
            collate_fn=collate_fn)
        return data_loader, validation_loader

    elif state in (1, 2, 3):
        # state maps 1->'valfast', 2->'test', 3->'val'
        image_sets = ['valfast', 'test', 'val']
        if method == 'lstr':
            if dataset not in lstr_datasets:
                raise ValueError('unsupported dataset for lstr: %r' % (dataset,))
            data_set = lstr_datasets[dataset](root=base,
                                              image_set=image_sets[state - 1],
                                              transforms=transforms_test,
                                              padding_mask=False,
                                              process_points=False)
        else:
            data_set = StandardLaneDetectionDataset(
                root=base,
                image_set=image_sets[state - 1],
                transforms=transforms_test,
                data_set=dataset)
        data_loader = torch.utils.data.DataLoader(dataset=data_set,
                                                  batch_size=batch_size,
                                                  collate_fn=collate_fn,
                                                  num_workers=workers,
                                                  shuffle=False)
        return data_loader
    else:
        raise ValueError('state must be 0, 1, 2 or 3, got %r' % (state,))
# Example #3
def run(p_seed=0, p_epochs=150, p_kernel_size=5, p_logdir="temp"):
    """Train an MNIST CNN (ModelM3/M5/M7) with EMA weights and log results.

    Args:
        p_seed: RNG seed for torch/cuda/numpy; also embedded in file names.
        p_epochs: number of training epochs.
        p_kernel_size: 3, 5 or 7 — selects the model architecture.
        p_logdir: subdirectory of ../logs for the log and checkpoint files.

    Raises:
        ValueError: if p_kernel_size is not one of 3, 5, 7.
    """
    # random number generator seed ------------------------------------------------#
    SEED = p_seed
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    np.random.seed(SEED)

    # kernel size of model --------------------------------------------------------#
    KERNEL_SIZE = p_kernel_size

    # number of epochs ------------------------------------------------------------#
    NUM_EPOCHS = p_epochs

    # file names ------------------------------------------------------------------#
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs("../logs/%s" % p_logdir, exist_ok=True)
    OUTPUT_FILE = "../logs/%s/log%03d.out" % (p_logdir, SEED)
    MODEL_FILE = "../logs/%s/model%03d.pth" % (p_logdir, SEED)

    # enable GPU usage ------------------------------------------------------------#
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if not use_cuda:
        # Training on CPU is intentionally refused.
        print("WARNING: CPU will be used for training.")
        exit(0)

    # data augmentation methods ---------------------------------------------------#
    transform = transforms.Compose([
        RandomRotation(20, seed=SEED),
        transforms.RandomAffine(0, translate=(0.2, 0.2)),
    ])

    # data loader -----------------------------------------------------------------#
    train_dataset = MnistDataset(training=True, transform=transform)
    test_dataset = MnistDataset(training=False, transform=None)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=120,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=100,
                                              shuffle=False)

    # model selection -------------------------------------------------------------#
    if KERNEL_SIZE == 3:
        model = ModelM3().to(device)
    elif KERNEL_SIZE == 5:
        model = ModelM5().to(device)
    elif KERNEL_SIZE == 7:
        model = ModelM7().to(device)
    else:
        # Previously an unknown size fell through and crashed later with a
        # NameError on `model`; fail fast with a clear message instead.
        raise ValueError("p_kernel_size must be 3, 5 or 7, got %r" % (KERNEL_SIZE,))

    summary(model, (1, 28, 28))

    # hyperparameter selection ----------------------------------------------------#
    ema = EMA(model, decay=0.999)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                          gamma=0.98)

    # truncate any previous result file -------------------------------------------#
    open(OUTPUT_FILE, 'w').close()

    # global variables ------------------------------------------------------------#
    g_step = 0
    max_correct = 0

    # training and evaluation loop ------------------------------------------------#
    for epoch in range(NUM_EPOCHS):
        #--------------------------------------------------------------------------#
        # train process                                                            #
        #--------------------------------------------------------------------------#
        model.train()
        train_loss = 0
        train_corr = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            train_pred = output.argmax(dim=1, keepdim=True)
            train_corr += train_pred.eq(
                target.view_as(train_pred)).sum().item()
            # Accumulate the summed (not averaged) loss for epoch statistics.
            train_loss += F.nll_loss(output, target, reduction='sum').item()
            loss.backward()
            optimizer.step()
            g_step += 1
            ema(model, g_step)  # update EMA shadow weights
            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{:05d}/{} ({:.0f}%)]\tLoss: {:.6f}'.
                      format(epoch, batch_idx * len(data),
                             len(train_loader.dataset),
                             100. * batch_idx / len(train_loader),
                             loss.item()))
        train_loss /= len(train_loader.dataset)
        train_accuracy = 100 * train_corr / len(train_loader.dataset)

        #--------------------------------------------------------------------------#
        # test process (evaluated with EMA weights swapped in)                     #
        #--------------------------------------------------------------------------#
        model.eval()
        ema.assign(model)
        test_loss = 0
        correct = 0
        total_pred = np.zeros(0)
        total_target = np.zeros(0)
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                total_pred = np.append(total_pred, pred.cpu().numpy())
                total_target = np.append(total_target, target.cpu().numpy())
                correct += pred.eq(target.view_as(pred)).sum().item()
            # Checkpoint whenever the correct count improves.
            if max_correct < correct:
                torch.save(model.state_dict(), MODEL_FILE)
                max_correct = correct
                print("Best accuracy! correct images: %5d" % correct)
        ema.resume(model)  # restore the live (non-EMA) weights

        #--------------------------------------------------------------------------#
        # output                                                                   #
        #--------------------------------------------------------------------------#
        test_loss /= len(test_loader.dataset)
        test_accuracy = 100 * correct / len(test_loader.dataset)
        best_test_accuracy = 100 * max_correct / len(test_loader.dataset)
        print(
            '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%) (best: {:.2f}%)\n'
            .format(test_loss, correct, len(test_loader.dataset),
                    test_accuracy, best_test_accuracy))

        # Context manager guarantees the handle is closed even if write fails.
        with open(OUTPUT_FILE, 'a') as f:
            f.write(" %3d %12.6f %9.3f %12.6f %9.3f %9.3f\n" %
                    (epoch, train_loss, train_accuracy, test_loss,
                     test_accuracy, best_test_accuracy))

        #--------------------------------------------------------------------------#
        # update learning rate scheduler                                           #
        #--------------------------------------------------------------------------#
        lr_scheduler.step()
# Example #4
                        best_loss = loss
                        best_params = [p.detach().cpu() for p in model.parameters()]
                    print("global step: %d (epoch: %d, step: %d), loss: %f %s" %
                          (global_step, epoch, inum, loss.item(), stats))

    except KeyboardInterrupt:
        pass

    if best_params is not None:
        print("\nLoading best params")
        model.parameters = best_params
        print("Params loaded")

    model.to('cpu')
    model.eval()
    model.save_to_drive('vae_trained')

if __name__ == '__main__':
    # Only the training set is augmented; the test set uses raw samples.
    train_dataset = ModelnetDataset(transform=RandomRotation())
    test_dataset = ModelnetDataset(transform=None)

    train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True,
                              num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=24, shuffle=True,
                             num_workers=1)

    # Decoder output is 3 coordinates per point.
    model = VAE(ENCODER_HIDDEN, [LATENT, DECODER_HIDDEN, 3 * OUT_POINTS])
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    train_model(model, optimizer, train_loader)
    def _init_fn(self):
        """Create datasets, model, optimizer and criterion for training.

        Builds the train/test transform pipelines (only training gets blur,
        color jitter and optional rotation), instantiates the RCTA keypoint
        datasets, the stacked-hourglass model, its RMSprop optimizer, and
        packs models/optimizers into dicts for the checkpoint saver.
        """
        jitter = self.options.jitter
        # Training transforms: photometric augmentation first, then optional
        # geometric augmentation, then crop/heatmap/tensor/normalize.
        transform_list = [
            RandomBlur(),
            ColorJitter(brightness=jitter,
                        contrast=jitter,
                        saturation=jitter,
                        hue=jitter / 4),
        ]
        if self.options.degrees > 0:
            transform_list.append(RandomRotation(degrees=self.options.degrees))
        if self.options.max_scale > 1:
            transform_list.append(RandomRescaleBB(1.0, self.options.max_scale))
        transform_list += [
            CropAndResize(out_size=(self.options.crop_size,
                                    self.options.crop_size)),
            LocsToHeatmaps(out_size=(self.options.heatmap_size,
                                     self.options.heatmap_size)),
            ToTensor(),
            Normalize(),
        ]

        # Test transforms: same geometry, no photometric augmentation.
        test_transform_list = []
        if self.options.max_scale > 1:
            test_transform_list.append(
                RandomRescaleBB(1.0, self.options.max_scale))
        test_transform_list += [
            CropAndResize(out_size=(self.options.crop_size,
                                    self.options.crop_size)),
            LocsToHeatmaps(out_size=(self.options.heatmap_size,
                                     self.options.heatmap_size)),
            ToTensor(),
            Normalize(),
        ]

        self.train_ds = RctaDataset(
            root_dir=self.options.dataset_dir,
            is_train=True,
            transform=transforms.Compose(transform_list))
        self.test_ds = RctaDataset(
            root_dir=self.options.dataset_dir,
            is_train=False,
            transform=transforms.Compose(test_transform_list))

        self.model = StackedHourglass(self.options.num_keypoints).to(
            self.device)
        print('Total number of model parameters:',
              self.model.num_trainable_parameters())

        # NOTE(review): SGD and Adam variants were previously considered
        # here (commented out); RMSprop is the optimizer actually used.
        self.optimizer = torch.optim.RMSprop(params=self.model.parameters(),
                                             lr=self.options.lr,
                                             momentum=0,
                                             weight_decay=self.options.wd)

        # Pack all models and optimizers in dictionaries to interact with
        # the checkpoint saver.
        self.models_dict = {'stacked_hg': self.model}
        self.optimizers_dict = {'optimizer': self.optimizer}

        # size_average=True is deprecated (removed in recent torch);
        # reduction='mean' is the equivalent modern spelling.
        self.criterion = nn.MSELoss(reduction='mean').to(self.device)
        self.pose = Pose2DEval(detection_thresh=self.options.detection_thresh,
                               dist_thresh=self.options.dist_thresh)