Example #1
File: dataset.py Project: kawarasoba/ukiyo
def main():

    # augmentation pipeline (defined here, but the loader below is called
    # with transform_aug=None, so it is not applied)
    transform_aug = Compose([
        aug.HueSaturationValue(),
        aug.RandomBrightnessContrast(),
        aug.CLAHE(),
        aug.JpegCompression(),
        aug.GaussNoise(),
        aug.MedianBlur(),
        aug.ElasticTransform(),
        aug.HorizontalFlip(),
        aug.Rotate(),
        aug.CoarseDropout(),
        aug.RandomSizedCrop()
    ], p=1)
    # transform for output: resize, then normalize pixel values to [-1, 1]
    transform = Compose([
        Resize(cons.IMAGE_SIZE, cons.IMAGE_SIZE),
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), max_pixel_value=255.0)
    ], p=1)

    # Dataset
    '''
    dataset = UkiyoeTrainDataset(
        train_images_path='data',
        train_labels_path='data',
        valid=False,
        confidence_border=0.87,
        result_path='result/model_effi_b3/efficientnet_b3_980/inference_with_c.csv',
        test_images_path='data',
        over_sampling=False,
        transform_aug=None,
        augmix=False,
        mixup=False,
        transform=transform)
    img, label = dataset[0]
    #print(img.shape)
    #plt.imshow(img)
    #plt.show()
    '''
    # train data loader
    loader = load_train_data(train_images_path='data',
                             train_labels_path='data',
                             batch_size=2,
                             valid=False,
                             nfold=0,
                             transform_aug=None,
                             augmix=True,
                             mixup=False,
                             transform=transform,
                             as_numpy=True)
    image_batch, label_batch = next(iter(loader))
    print(image_batch[0].shape)
    print(label_batch[0].shape)
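
The excerpt above omits its imports. A minimal sketch of the header it appears to presume, assuming `aug` aliases albumentations and `cons` is a project constants module; `UkiyoeTrainDataset` and `load_train_data` come from elsewhere in the project, and the module names here are guesses, not confirmed by the source:

# assumed imports for the excerpt above (module names are hypothetical)
from albumentations import Compose, Resize, Normalize
import albumentations as aug   # the transform names match albumentations' API
import constants as cons       # hypothetical: must define IMAGE_SIZE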
Example #2
def data_generator(data=None,
                   meta_data=None,
                   labels=None,
                   batch_size=16,
                   augment={},
                   opt_shuffle=True):

    indices = list(range(len(labels)))

    while True:

        if opt_shuffle:
            shuffle(indices)

        x_data = np.copy(data)
        x_meta_data = np.copy(meta_data)
        x_labels = np.copy(labels)

        for start in range(0, len(labels), batch_size):
            end = min(start + batch_size, len(labels))
            sel_indices = indices[start:end]

            # select data for this batch via the shuffled indices
            data_batch = x_data[sel_indices]
            xm_batch = x_meta_data[sel_indices]
            y_batch = x_labels[sel_indices]
            x_batch = []

            for x in data_batch:

                # apply the augmentations enabled in the augment dict
                if augment.get('Rotate', False):
                    x = aug.Rotate(x, u=0.1, v=np.random.random())
                    x = aug.Rotate90(x, u=0.1, v=np.random.random())

                if augment.get('Shift', False):
                    x = aug.Shift(x, u=0.05, v=np.random.random())

                if augment.get('Zoom', False):
                    x = aug.Zoom(x, u=0.05, v=np.random.random())

                if augment.get('Flip', False):
                    x = aug.HorizontalFlip(x, u=0.5, v=np.random.random())
                    x = aug.VerticalFlip(x, u=0.5, v=np.random.random())

                x_batch.append(x)

            x_batch = np.array(x_batch, np.float32)

            yield [x_batch, xm_batch], y_batch
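
A minimal usage sketch for data_generator, with toy arrays standing in for the real inputs; the shapes, batch size, and empty augment dict are illustrative only. Note the generator itself relies on numpy and random.shuffle being imported:

import numpy as np
from random import shuffle  # required by data_generator above

# toy stand-ins: 8 RGB images, 3 metadata features, one-hot labels for 2 classes
data = np.random.rand(8, 32, 32, 3).astype(np.float32)
meta_data = np.random.rand(8, 3).astype(np.float32)
labels = np.eye(2, dtype=np.float32)[np.random.randint(0, 2, size=8)]

gen = data_generator(data=data, meta_data=meta_data, labels=labels,
                     batch_size=4, augment={})  # all augmentations disabled
(x_batch, xm_batch), y_batch = next(gen)
print(x_batch.shape, xm_batch.shape, y_batch.shape)  # (4, 32, 32, 3) (4, 3) (4, 2)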
Example #3
File: train.py Project: kawarasoba/ukiyo
def main(argv=None):

    transform = Compose([
        Resize(cons.IMAGE_SIZE, cons.IMAGE_SIZE),
        Normalize(mean=(0.5, 0.5, 0.5),
                  std=(0.5, 0.5, 0.5),
                  max_pixel_value=255.0)
    ])
    valid_loader = load_train_data(train_images_path=FLAGS.train_images_path,
                                   train_labels_path=FLAGS.train_labels_path,
                                   batch_size=FLAGS.batch_size,
                                   num_worker=FLAGS.num_worker,
                                   valid=True,
                                   nfold=FLAGS.nfold,
                                   transform=transform)

    model = models.get_model(model_name=FLAGS.model_name,
                             num_classes=cons.NUM_CLASSES)
    model.cuda()
    #model = torch.nn.DataParallel(model)

    DIR = '/' + FLAGS.case + '/' + FLAGS.model_name + '/fold' + str(
        FLAGS.nfold)
    RESULT_PATH = ''
    if FLAGS.confidence_border is not None:
        DIR = DIR + '/with_pseudo_labeling'
        RESULT_PATH = RESULT_PATH + FLAGS.result_path
        if FLAGS.result_case is not None:
            RESULT_PATH = RESULT_PATH + '/' + FLAGS.result_case
        RESULT_PATH = RESULT_PATH + '/inference_with_c.csv'

    PARAM_DIR = FLAGS.params_path + DIR
    os.makedirs(PARAM_DIR, exist_ok=True)
    PARAM_NAME = PARAM_DIR + '/' + FLAGS.case
    if FLAGS.executed_epoch > 0:
        TRAINED_PARAM_PATH = FLAGS.restart_param_path + '/' + FLAGS.case + str(
            FLAGS.executed_epoch)
        restart_epoch = FLAGS.executed_epoch + 1
        if FLAGS.restart_from_final:
            TRAINED_PARAM_PATH = TRAINED_PARAM_PATH + '_final'
        TRAINED_PARAM_PATH = TRAINED_PARAM_PATH + '.pth'
        model.load_state_dict(torch.load(TRAINED_PARAM_PATH))
    else:
        restart_epoch = 0

    optimizer = optim.Adam(model.parameters(), lr=cons.start_lr)
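    # NVIDIA Apex AMP wraps the model and optimizer for mixed-precision
    # training at the level selected by FLAGS.opt_level ('O0'-'O3')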
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=FLAGS.opt_level)

    if FLAGS.add_class_weight:
        loader = load_train_data(train_images_path=FLAGS.train_images_path,
                                 train_labels_path=FLAGS.train_labels_path,
                                 batch_size=FLAGS.batch_size,
                                 num_worker=FLAGS.num_worker,
                                 nfold=FLAGS.nfold)
        # accumulate per-class counts from the one-hot label batches
        count_label = np.zeros(cons.NUM_CLASSES, dtype=np.int64)
        for feed in loader:
            _, labels = feed
            count_label += np.sum(labels.numpy().astype(np.int64), axis=0)
        # NOTE: raw counts are used directly as per-class weights, so frequent
        # classes contribute more to the loss
        weight = torch.from_numpy(count_label).cuda()
    else:
        weight = None
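    # BCEWithLogitsLoss scores the one-hot targets as independent binary
    # problems; a per-class weight tensor broadcasts across the class axis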
    criterion = nn.BCEWithLogitsLoss(weight=weight)

    writer = SummaryWriter(log_dir=FLAGS.logs_path + DIR + '/tensorboardX/')
    best_acc = 0

    if FLAGS.augmentation and FLAGS.aug_decrease:
        p = 0.5

        for e in range(restart_epoch, FLAGS.final_epoch):
            # augmentation probability decays linearly to zero over training
            p_partial = p * (FLAGS.final_epoch - e) / FLAGS.final_epoch

            lr = set_lr.cosine_annealing(optimizer, cons.start_lr, e, 100)
            writer.add_scalar('LearningRate', lr, e)

            train_loader = load_train_data(
                train_images_path=FLAGS.train_images_path,
                train_labels_path=FLAGS.train_labels_path,
                batch_size=FLAGS.batch_size,
                num_worker=FLAGS.num_worker,
                nfold=FLAGS.nfold,
                confidence_border=FLAGS.confidence_border,
                result_path=RESULT_PATH,
                test_images_path=FLAGS.test_images_path,
                over_sampling=FLAGS.over_sampling,
                transform_aug=Compose([
                    aug.HueSaturationValue(p=p_partial),
                    aug.RandomBrightnessContrast(p=p_partial),
                    aug.CLAHE(p=p_partial),
                    aug.JpegCompression(p=p_partial),
                    aug.GaussNoise(p=p),
                    aug.MedianBlur(p=p),
                    aug.ElasticTransform(p=p_partial),
                    aug.HorizontalFlip(p=p),
                    aug.Rotate(p=p),
                    aug.CoarseDropout(p=p_partial),
                    aug.RandomSizedCrop(p=p)
                ]),
                mixup=FLAGS.mixup,
                transform=transform)

            train_loss = train_loop(model, train_loader, criterion, optimizer)
            writer.add_scalar('train_loss', train_loss, e)

            valid_loss, valid_acc = valid_loop(model, valid_loader, criterion)
            writer.add_scalar('valid_loss', valid_loss, e)
            writer.add_scalar('valid_acc', valid_acc, e)

            print(
                'Epoch: {}, Train Loss: {:.4f}, Valid Loss: {:.4f}, Valid Accuracy: {:.2f}'
                .format(e + 1, train_loss, valid_loss, valid_acc))
            if e % 10 == 0:
                torch.save(model.state_dict(),
                           PARAM_NAME + '_' + str(e) + '.pth')
            if valid_acc > best_acc:
                best_acc = valid_acc
                torch.save(model.state_dict(), PARAM_NAME + '_best.pth')
    else:

        if FLAGS.augmentation and not FLAGS.augmix:
            transform_aug = Compose([
                aug.HueSaturationValue(),
                aug.RandomBrightnessContrast(),
                aug.CLAHE(),
                aug.JpegCompression(),
                aug.GaussNoise(),
                aug.MedianBlur(),
                aug.ElasticTransform(),
                aug.HorizontalFlip(),
                aug.Rotate(),
                aug.CoarseDropout(),
                aug.RandomSizedCrop()
            ])
        else:
            transform_aug = None

        train_loader = load_train_data(
            train_images_path=FLAGS.train_images_path,
            train_labels_path=FLAGS.train_labels_path,
            batch_size=FLAGS.batch_size,
            num_worker=FLAGS.num_worker,
            valid=False,
            nfold=FLAGS.nfold,
            over_sampling=FLAGS.over_sampling,
            transform_aug=transform_aug,
            augmix=FLAGS.augmix,
            mixup=FLAGS.mixup,
            transform=transform)

        total_time = 0
        for e in range(restart_epoch, FLAGS.final_epoch):
            start = time.time()
            lr = set_lr.cosine_annealing(optimizer, cons.start_lr, e, 100)
            writer.add_scalar('LearningRate', lr, e)
            train_loss = train_loop(model, train_loader, criterion, optimizer)
            writer.add_scalar('train_loss', train_loss, e)
            valid_loss, valid_acc = valid_loop(model, valid_loader, criterion)
            writer.add_scalar('valid_loss', valid_loss, e)
            writer.add_scalar('valid_acc', valid_acc, e)
            print(
                'Epoch: {}, Train Loss: {:.4f}, Valid Loss: {:.4f}, Valid Accuracy: {:.2f}'
                .format(e + 1, train_loss, valid_loss, valid_acc))
            if e % 10 == 0:
                torch.save(model.state_dict(),
                           PARAM_NAME + '_' + str(e) + '.pth')
            if valid_acc > best_acc:
                best_acc = valid_acc
                torch.save(model.state_dict(), PARAM_NAME + '_best.pth')
            total_time += time.time() - start
            # average over the epochs actually run in this session
            print('average time: {}[sec]'.format(total_time / (e - restart_epoch + 1)))

    torch.save(model.state_dict(),
               PARAM_NAME + '_' + str(FLAGS.final_epoch - 1) + '_final.pth')
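
set_lr.cosine_annealing(optimizer, cons.start_lr, e, 100) is project-internal and not shown here. A plausible sketch, under the assumption that it implements standard cosine annealing over a fixed period and returns the new rate (the train loop logs its return value); this is a guess at the project's implementation, not its actual code:

import math

def cosine_annealing(optimizer, base_lr, epoch, period, min_lr=0.0):
    # assumed implementation: decay from base_lr toward min_lr along
    # 0.5 * (1 + cos(pi * t / T)), restarting every `period` epochs
    lr = min_lr + 0.5 * (base_lr - min_lr) * (
        1 + math.cos(math.pi * (epoch % period) / period))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr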