예제 #1
0
def get(args):
    """Build CIFAR-100 train/validation data providers.

    Returns (train_dataprovider, val_dataprovider, val_steps) where
    val_steps is the number of full validation batches per epoch.
    """
    # Load the CIFAR-100 train/validation datasets.
    dataset_train, dataset_valid = dataset_cifar.get_dataset("cifar100")

    # Optional stratified train/valid split; disabled here (split == 0.0),
    # so the whole training set is used with plain shuffling.
    split = 0.0
    split_idx = 0
    train_sampler = None
    if split > 0.0:
        splitter = StratifiedShuffleSplit(n_splits=5,
                                          test_size=split,
                                          random_state=0)
        fold_iter = splitter.split(list(range(len(dataset_train))),
                                   dataset_train.targets)
        # Advance to the requested fold.
        for _ in range(split_idx + 1):
            train_idx, valid_idx = next(fold_iter)
        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetSampler(valid_idx)
    else:
        valid_sampler = SubsetSampler([])

    # Shuffle only when no explicit sampler is supplied (DataLoader forbids
    # combining shuffle=True with a sampler).
    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.train_batch_size,
        shuffle=train_sampler is None,
        num_workers=32,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True)

    valid_loader = torch.utils.data.DataLoader(
        dataset_valid,
        batch_size=args.test_batch_size,
        shuffle=False,
        num_workers=16,
        pin_memory=True,
        drop_last=True)

    train_dataprovider = DataIterator(train_loader)
    val_dataprovider = DataIterator(valid_loader)

    val_steps = len(dataset_valid) // args.test_batch_size
    print(
        'load valid dataset successfully, and images:{}, batchsize:{}, step:{}'
        .format(len(dataset_valid), args.test_batch_size, val_steps))
    return train_dataprovider, val_dataprovider, val_steps
예제 #2
0
def main():
    """Train a LAPGAN on CIFAR-10 with alternating D/G updates.

    The discriminator update is skipped while it is "overpowered" (its loss
    fell below half the generator loss at the last checkpoint), and every
    `checkpoint_interval` steps the losses, summaries and a fixed-noise
    sample grid are dumped.
    """
    with tf.Session() as sess:
        num_epoch = 5
        checkpoint_interval = 10

        batch_size = 64
        # NOTE(review): only referenced by the commented-out pyramid code
        # below; unused on the active path.
        image_size = 32

        model = LAPGAN(sess, batch_size=batch_size)

        dataset = Dataset("cifar10/")
        dataset_iter = DataIterator(dataset.train_images, dataset.train_labels, batch_size)

        # Legacy (pre-1.0) TF summary/init APIs.
        summary_writer = tf.train.SummaryWriter('logs_{0}/'.format(int(time.time())), sess.graph_def)

        sess.run(tf.initialize_all_variables())

        # Fixed validation images and noise, reused for every sample dump so
        # the generated grids are comparable across steps.
        sample_images = dataset.valid_images[:model.sample_size].astype(np.float32) / 255.0
        sample_z = np.random.uniform(-1.0, 1.0, size=(model.sample_size , model.z_dim))

        d_overpowered = False

        step = 0
        for epoch in range(num_epoch):
            for batch_images, _ in dataset_iter.iterate():
                # I0 = batch_images / 255.0
                # I1 = downsample(tf.constant(I0, tf.float32))
                # l0 = sess.run(upsample(I1))
                # h0 = I0 - l0
                # z0 = np.random.uniform(-1.0, 1.0, (batch_size,) + image_size + (1,)).astype(np.float32)
                # l0 = np.concatenate([l0, z0], axis=-1)

                # Scale pixels to [0, 1] and draw fresh latent noise.
                batch_images = batch_images.astype(np.float32) / 255.0
                batch_z = np.random.uniform(-1.0, 1.0, [batch_size, model.z_dim]).astype(np.float32)

                # update d network
                # Skipped while D is winning too hard, to let G catch up.
                if not d_overpowered:
                    sess.run(model.d_optim, feed_dict={ model.x: batch_images, model.z: batch_z })

                # update g network
                sess.run(model.g_optim, feed_dict={ model.z: batch_z })

                if step % checkpoint_interval == 0:
                    # I0 = dataset.valid_images / 255.0
                    # I1 = downsample(tf.constant(I0, tf.float32))
                    # l0 = sess.run(upsample(I1))
                    # h0 = I0 - l0
                    # z0 = np.random.uniform(-1.0, 1.0, I0.shape[:-1] + (1,)).astype(np.float32)
                    # l0 = np.concatenate([l0, z0], axis=-1)

                    # Evaluate both losses on a fixed validation batch (the
                    # batch_images/batch_z names are reused deliberately).
                    batch_images = dataset.valid_images[:batch_size].astype(np.float32) / 255.0
                    batch_z = np.random.uniform(-1.0, 1.0, [batch_size, model.z_dim]).astype(np.float32)

                    d_loss, g_loss, summary = sess.run([
                        model.d_loss,
                        model.g_loss,
                        model.merged
                    ], feed_dict={
                        model.x: batch_images,
                        model.z: batch_z
                    })

                    # D counts as overpowered when its loss falls below half
                    # of G's.
                    d_overpowered = d_loss < g_loss / 2

                    samples = sess.run(model.G, feed_dict={
                        model.x: sample_images,
                        model.z: sample_z
                    })

                    summary_writer.add_summary(summary, step)
                    save_images(samples, [8, 8], './samples/train_{0}_{1}.png'.format(epoch, step))
                    print('[{0}, {1}] loss: {2} (D) {3} (G) (d overpowered?: {4})'.format(epoch, step, d_loss, g_loss, d_overpowered))

                step += 1
예제 #3
0
    # NOTE(review): fragment — the enclosing function's header and the rest
    # of the epoch-loop body are not visible in this file.
    # Scalar loss summaries (legacy, pre-1.0 TF summary API).
    tf.scalar_summary("loss_d_fake", loss_d_fake)
    tf.scalar_summary("loss_d", loss_d)
    tf.scalar_summary("loss_g", loss_g)

    # Partition trainable variables by scope-name substring so each
    # optimizer only updates its own network.
    vars_g = [var for var in tf.trainable_variables() if 'generator' in var.name]
    vars_d = [var for var in tf.trainable_variables() if 'discrimin' in var.name]

    # Plain SGD for both networks at a fixed 2e-4 learning rate.
    train_g = tf.train.GradientDescentOptimizer(2e-4).minimize(loss_g, var_list=vars_g)
    train_d = tf.train.GradientDescentOptimizer(2e-4).minimize(loss_d, var_list=vars_d)

    merged_summaries = tf.merge_all_summaries()

    sess.run(tf.initialize_all_variables())

    dataset = Dataset("cifar10/")
    dataset_iter = DataIterator(dataset.train_images, dataset.train_labels, batch_size)

    print('train', dataset.train_images.shape, dataset.train_labels.shape)
    print('valid', dataset.valid_images.shape, dataset.valid_labels.shape)
    print('test', dataset.test_images.shape, dataset.test_labels.shape)

    summary_writer = tf.train.SummaryWriter('logs_{0}/'.format(int(time.time())), sess.graph_def)

    for epoch in range(num_epochs):
        if epoch % checkpoint_interval == 0:
            # Pyramid-style decomposition of the validation images: l0 is
            # presumably the low-pass reconstruction and h0 the residual —
            # TODO confirm downsample/upsample semantics (defined elsewhere).
            I0 = dataset.valid_images / 255.0
            I1 = downsample(tf.constant(I0, tf.float32))
            l0 = sess.run(upsample(I1))
            h0 = I0 - l0

            # Extra uniform-noise channel matching I0's spatial dimensions.
            z0 = np.random.uniform(-1.0, 1.0, I0.shape[:-1] + (1,)).astype(np.float32)
예제 #4
0
파일: train.py 프로젝트: wangh-allen/sCNNs
def main():
    """Train the RCAN super-resolution model on DIV2K (Track 1, bicubic x4).

    Loads LR/HR pairs (from raw images or cached .h5 files), saves one
    sample pair for visual sanity-checking, then trains with periodic
    logging, checkpointing, best-model tracking and step-wise learning-rate
    decay.
    """
    start_time = time.time()  # Clocking start

    # Div2K - Track 1: Bicubic downscaling - x4 DataSet load
    if data_from == 'img':
        ds = DataSet(ds_path=config.data_dir,
                     ds_name="X4",
                     use_save=True,
                     save_type="to_h5",
                     save_file_name=config.data_dir + "DIV2K",
                     use_img_scale=False,
                     n_patch=config.patch_size)
    else:  # .h5 files
        ds = DataSet(ds_hr_path=config.data_dir + "DIV2K-hr.h5",
                     ds_lr_path=config.data_dir + "DIV2K-lr.h5",
                     use_img_scale=False,
                     n_patch=config.patch_size)

    # [0, 1] scaled images
    if config.patch_size > 0:
        hr, lr = ds.patch_hr_images, ds.patch_lr_images
    else:
        hr, lr = ds.hr_images, ds.lr_images

    lr_shape = lr.shape[1:]
    hr_shape = hr.shape[1:]
    print("[+] Loaded LR patch image ", lr.shape)
    print("[+] Loaded HR patch image ", hr.shape)

    # setup output directory (idempotent)
    os.makedirs(config.output_dir, exist_ok=True)

    # Side length of the sample grid; 0 when patching is disabled. This
    # guards the periodic image dump below — previously `patch` was only
    # defined in the patch branch and raised NameError otherwise.
    patch = int(np.sqrt(config.patch_size)) if config.patch_size > 0 else 0

    # sample LR/HR pair
    rnd = np.random.randint(0, ds.n_images)
    if config.patch_size > 0:
        sample_lr = lr[config.patch_size * rnd:config.patch_size *
                       (rnd + 1), :, :, :]
        sample_lr = np.reshape(sample_lr, (config.patch_size, ) +
                               lr_shape)  # (16,) + lr_shape

        sample_hr = hr[config.patch_size * rnd:config.patch_size *
                       (rnd + 1), :, :, :]
        sample_hr = np.reshape(sample_hr, (config.patch_size, ) +
                               hr_shape)  # (16,) + hr_shape

        util.img_save(
            img=util.merge(sample_lr, (patch, patch)),
            path=config.output_dir + "/sample_lr.png",
            use_inverse=False,
        )
        util.img_save(
            img=util.merge(sample_hr, (patch, patch)),
            path=config.output_dir + "/sample_hr.png",
            use_inverse=False,
        )
    else:
        sample_lr = np.reshape(lr[rnd], lr_shape)  # lr_shape
        sample_hr = np.reshape(hr[rnd], hr_shape)  # hr_shape

        util.img_save(
            img=sample_lr,
            path=config.output_dir + "/sample_lr.png",
            use_inverse=False,
        )
        util.img_save(
            img=sample_hr,
            path=config.output_dir + "/sample_hr.png",
            use_inverse=False,
        )
        # scaling into lr [0, 1]
        sample_lr /= 255.

    # DataIterator over (LR, HR) mini-batches
    di = DataIterator(lr, hr, config.batch_size)
    rcan_model = model.RCAN(
        lr_img_size=lr_shape[:-1],
        hr_img_size=hr_shape[:-1],
        batch_size=config.batch_size,
        img_scaling_factor=config.image_scaling_factor,
        n_res_blocks=config.n_res_blocks,
        n_res_groups=config.n_res_groups,
        res_scale=config.res_scale,
        n_filters=config.filter_size,
        kernel_size=config.kernel_size,
        activation=config.activation,
        use_bn=config.use_bn,
        reduction=config.reduction,
        optimizer=config.optimizer,
        lr=config.lr,
        lr_decay=config.lr_decay,
        lr_decay_step=config.lr_decay_step,
        momentum=config.momentum,
        beta1=config.beta1,
        beta2=config.beta2,
        opt_eps=config.opt_epsilon,
        tf_log=config.summary,
        n_gpu=config.n_gpu,
    )
    # gpu config
    gpu_config = tf.GPUOptions(allow_growth=True)
    tf_config = tf.ConfigProto(allow_soft_placement=True,
                               log_device_placement=False,
                               gpu_options=gpu_config)

    with tf.Session(config=tf_config) as sess:

        # Initializing
        writer = tf.summary.FileWriter(config.summary, sess.graph)
        sess.run(tf.global_variables_initializer())

        # Load model & Graph & Weights. The global step is encoded in the
        # checkpoint filename ("...-<step>").
        global_step = 0
        ckpt = tf.train.get_checkpoint_state(config.summary)
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            rcan_model.saver.restore(sess, ckpt.model_checkpoint_path)
            global_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print("[+] global step : %d" % global_step, " successfully loaded")
        else:
            print('[-] No checkpoint file found')

        # Reconstruct the decayed learning rate after a restore. The in-loop
        # schedule below multiplies by lr_decay once per lr_decay_step
        # steps, so the restored value must use an exponent — the original
        # product (lr_decay * n) under-decayed and could even raise the lr.
        lr = config.lr if global_step < config.lr_decay_step \
            else config.lr * (config.lr_decay ** (global_step // config.lr_decay_step))

        rcan_model.global_step.assign(tf.constant(global_step))
        start_epoch = global_step // (ds.n_images // config.batch_size)

        best_loss = 1e8
        for epoch in range(start_epoch, config.epochs):
            for x_lr, x_hr in di.iterate():
                # scaling into lr [0, 1] # hr [0, 255]
                x_lr = np.true_divide(x_lr, 255., casting='unsafe')
                # training
                _, loss, psnr, ssim = sess.run(
                    [
                        rcan_model.train_op, rcan_model.loss, rcan_model.psnr,
                        rcan_model.ssim
                    ],
                    feed_dict={
                        rcan_model.x_lr: x_lr,
                        rcan_model.x_hr: x_hr,
                        rcan_model.lr: lr,
                    })

                if global_step % config.logging_step == 0:
                    print(
                        "[+] %d epochs %d steps" % (epoch, global_step),
                        "loss : {:.8f} PSNR : {:.4f} SSIM : {:.4f}".format(
                            loss, psnr, ssim))
                    # summary & output
                    summary, output = sess.run(
                        [rcan_model.merged, rcan_model.output],
                        feed_dict={
                            rcan_model.x_lr: x_lr,
                            rcan_model.x_hr: x_hr,
                            rcan_model.lr: lr,
                        })
                    writer.add_summary(summary, global_step)
                    # model save
                    rcan_model.saver.save(sess, config.summary, global_step)

                    if loss < best_loss:
                        print("[*] improved {:.8f} to {:.8f}".format(
                            best_loss, loss))
                        rcan_model.best_saver.save(sess, './best/',
                                                   global_step)
                        best_loss = loss

                # Periodic merged-output dump; skipped when patching is
                # disabled (patch == 0), where no grid layout exists.
                # (`output` is defined here because logging_step * 10
                # multiples are always logging_step multiples.)
                if patch and global_step % (config.logging_step * 10) == 0:
                    util.img_save(img=util.merge(output, (patch, patch)),
                                  path=config.output_dir +
                                  "/%d.png" % global_step,
                                  use_inverse=False)
                # lr schedule
                if global_step and global_step % config.lr_decay_step == 0:
                    lr *= config.lr_decay

                # increase global step
                rcan_model.global_step.assign_add(tf.constant(1))
                global_step += 1

    end_time = time.time() - start_time  # Clocking end

    # Elapsed time
    print("[+] Elapsed time {:.8f}s".format(end_time))
예제 #5
0
파일: test.py 프로젝트: wangh-allen/sCNNs
def main():
    """Evaluate the trained RCAN model on the DIV2K test split.

    Restores the latest checkpoint, computes PSNR/SSIM per batch, writes
    merged output grids to ./output/test, and prints the averages.
    """
    # Build the test DataSet from raw images or from cached .h5 dumps.
    if data_from == 'img':
        ds = DataSet(ds_path=config.test_dir,
                     ds_name="X4",
                     use_save=True,
                     save_type="to_h5",
                     save_file_name=config.test_dir + "DIV2K",
                     use_img_scale=False,
                     n_patch=config.patch_size,
                     n_images=100,
                     is_train=False)
    else:  # .h5 files
        ds = DataSet(ds_hr_path=config.test_dir + "DIV2K-hr.h5",
                     ds_lr_path=config.test_dir + "DIV2K-lr.h5",
                     use_img_scale=False,
                     n_patch=config.patch_size,
                     n_images=100,
                     is_train=False)

    # [0, 1] scaled images (patched or whole, depending on config)
    hr = ds.patch_hr_images if config.patch_size > 0 else ds.hr_images
    lr = ds.patch_lr_images if config.patch_size > 0 else ds.lr_images

    lr_shape, hr_shape = lr.shape[1:], hr.shape[1:]
    print("[+] Loaded LR patch image ", lr.shape)
    print("[+] Loaded HR patch image ", hr.shape)

    data_iter = DataIterator(lr, hr, config.batch_size)
    net = model.RCAN(
        lr_img_size=lr_shape[:-1],
        hr_img_size=hr_shape[:-1],
        batch_size=config.batch_size,
        img_scaling_factor=config.image_scaling_factor,
        n_res_blocks=config.n_res_blocks,
        n_res_groups=config.n_res_groups,
        res_scale=config.res_scale,
        n_filters=config.filter_size,
        kernel_size=config.kernel_size,
        activation=config.activation,
        use_bn=config.use_bn,
        reduction=config.reduction,
        optimizer=config.optimizer,
        lr=config.lr,
        lr_decay=config.lr_decay,
        lr_decay_step=config.lr_decay_step,
        momentum=config.momentum,
        beta1=config.beta1,
        beta2=config.beta2,
        opt_eps=config.opt_epsilon,
        tf_log=config.summary,
        n_gpu=config.n_gpu,
    )

    # GPU config: grow memory on demand, allow soft device placement.
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        gpu_options=tf.GPUOptions(allow_growth=True))

    with tf.Session(config=sess_config) as sess:
        # Summary writer and variable init.
        writer = tf.summary.FileWriter(config.test_log, sess.graph)
        sess.run(tf.global_variables_initializer())

        # Restore trained weights; testing without them is meaningless.
        ckpt = tf.train.get_checkpoint_state(config.summary)
        if not (ckpt and ckpt.model_checkpoint_path):
            raise OSError("[-] No checkpoint file found")
        net.saver.restore(sess, ckpt.model_checkpoint_path)

        # Grid side length for merging patch outputs (constant per run).
        patch = int(np.sqrt(config.patch_size))

        psnr_scores, ssim_scores = [], []
        for i, (x_lr, x_hr) in enumerate(data_iter.iterate()):
            x_lr = np.true_divide(x_lr, 255., casting='unsafe')

            psnr, ssim, summary, output = sess.run(
                [net.psnr, net.ssim, net.merged, net.output],
                feed_dict={
                    net.x_lr: x_lr,
                    net.x_hr: x_hr,
                    net.lr: config.lr,
                })
            writer.add_summary(summary)
            psnr_scores.append(psnr)
            ssim_scores.append(ssim)

            # Save the merged output grid for this batch.
            img_save(merge(output, (patch, patch)),
                     './output/test' + '/%d.png' % i,
                     use_inverse=False)
            print("%d images tested, " % i,
                  "PSNR : {:.4f} SSIM : {:.4f}".format(psnr, ssim))

        print("total PSNR is {:.4f}, SSIM is {:.4f}".format(
            sum(psnr_scores) / len(psnr_scores),
            sum(ssim_scores) / len(ssim_scores)))
예제 #6
0
파일: train.py 프로젝트: Wedeueis/Projetos
def main():
    """Train a ProgGAN on CIFAR-10 with alternating D/G updates.

    Skips the discriminator update while it is "overpowered"
    (D loss < half of G loss) and periodically logs losses, saves
    summaries, sample images and model checkpoints.
    """
    start_time = time.time()  # Clocking start

    # GPU configure
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as s:
        # DCGAN Model
        model = prog_gan.ProgGAN(s)
        # Dataset load
        dataset = Dataset("../data/CIFAR10_data/")
        dataset_iter = DataIterator(dataset.train_images, dataset.train_labels,
                                    model.batch_size)

        # Initializing
        s.run(tf.global_variables_initializer())

        # Fixed validation images and noise, reused for every sample dump so
        # the generated images are comparable across steps.
        sample_x = dataset.valid_images[:model.sample_num].astype(
            np.float32) / 255.0
        sample_z = np.random.uniform(
            -1., 1., [model.sample_num, model.z_dim]).astype(np.float32)

        d_overpowered = False

        step = 0
        for epoch in range(train_step['epoch']):
            for batch_images, _ in dataset_iter.iterate():
                # Scale pixels to [0, 1]; fresh latent noise each step.
                batch_x = batch_images.astype(np.float32) / 255.0
                batch_z = np.random.uniform(
                    -1., 1.,
                    [model.batch_size, model.z_dim]).astype(np.float32)

                # Update D network
                if not d_overpowered:
                    _, d_loss = s.run([model.d_op, model.d_loss],
                                      feed_dict={
                                          model.x: batch_x,
                                          model.z: batch_z,
                                      })

                # Update G network
                _, g_loss = s.run([model.g_op, model.g_loss],
                                  feed_dict={
                                      model.x: batch_x,
                                      model.z: batch_z,
                                  })

                # NOTE(review): when the D update was skipped above, d_loss
                # here is the value carried over from an earlier iteration.
                d_overpowered = d_loss < g_loss / 2

                # Logging
                if step % train_step['logging_interval'] == 0:
                    # Re-evaluate both losses on a fixed validation batch.
                    batch_x = dataset.valid_images[:model.batch_size].astype(
                        np.float32) / 255.0
                    batch_z = np.random.uniform(
                        -1., 1.,
                        [model.batch_size, model.z_dim]).astype(np.float32)

                    d_loss, g_loss, summary = s.run(
                        [model.d_loss, model.g_loss, model.merged],
                        feed_dict={
                            model.x: batch_x,
                            model.z: batch_z,
                        })

                    d_overpowered = d_loss < g_loss / 2

                    # Print loss
                    print("[+] Step %08d => " % step,
                          "Dloss: {:.8f}".format(d_loss),
                          "Gloss: {:.8f}".format(g_loss))

                    # Training G model with sample image and noise
                    samples = s.run(model.g,
                                    feed_dict={
                                        model.x: sample_x,
                                        model.z: sample_z,
                                    })

                    # Summary saver
                    model.writer.add_summary(summary, step)

                    # Export image generated by model G
                    # NOTE(review): assumes sample_num fills a
                    # sample_size x sample_size grid — confirm in ProgGAN.
                    sample_image_height = model.sample_size
                    sample_image_width = model.sample_size
                    sample_dir = results['output'] + 'train_{:08d}.png'.format(
                        step)

                    # Generated image save
                    iu.save_images(
                        samples,
                        size=[sample_image_height, sample_image_width],
                        image_path=sample_dir)

                    # Model save
                    model.saver.save(s, results['model'], global_step=step)

                step += 1

    end_time = time.time() - start_time  # Clocking end

    # Elapsed time
    print("[+] Elapsed time {:.8f}s".format(end_time))

    # Close tf.Session
    # NOTE(review): redundant — the `with` block above already closed it.
    s.close()
예제 #7
0
def prepare():
    """Build data providers, model, optimizer and LR scheduler from CLI
    args, then run the NNI train/test loop for args.epochs epochs.
    """
    args = get_args()

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    if args.cifar100:
        train_dataprovider, val_dataprovider, train_step, valid_step = dataset_cifar.get_dataset(
            "cifar100", batch_size=args.batch_size, RandA=args.randAugment)
        print('load data successfully')
    else:
        # ImageNet-style folder datasets.
        # NOTE(review): train_step/valid_step are only assigned in the
        # cifar100 branch, so this path would hit a NameError at the loop
        # below — confirm it is ever taken.
        assert os.path.exists(args.train_dir)
        from dataset import DataIterator, SubsetSampler, OpencvResize, ToBGRTensor
        train_dataset = datasets.ImageFolder(
            args.train_dir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.ColorJitter(brightness=0.4,
                                       contrast=0.4,
                                       saturation=0.4),
                transforms.RandomHorizontalFlip(0.5),
                ToBGRTensor(),
            ]))
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=1,
                                                   pin_memory=use_gpu)
        train_dataprovider = DataIterator(train_loader)

        assert os.path.exists(args.val_dir)
        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            args.val_dir,
            transforms.Compose([
                OpencvResize(256),
                transforms.CenterCrop(224),
                ToBGRTensor(),
            ])),
                                                 batch_size=200,
                                                 shuffle=False,
                                                 num_workers=1,
                                                 pin_memory=use_gpu)
        val_dataprovider = DataIterator(val_loader)
        print('load data successfully')

    # Imagenet
    # from network import ShuffleNetV2_OneShot
    # model = ShuffleNetV2_OneShot(n_class=1000)

    # Special for cifar
    from network_origin import cifar_fast
    model = cifar_fast(input_size=32, n_class=100)

    # Optimizer
    optimizer = get_optim(args, model)

    # Label Smooth
    if args.criterion_smooth:
        criterion = CrossEntropyLabelSmooth(100, 0.1)
    else:
        criterion = nn.CrossEntropyLoss()

    # LR schedule: linear decay to zero over args.epochs * train_step steps,
    # or cosine annealing over args.epochs.
    # NOTE(review): scheduler stays unbound for any other args.lr_scheduler
    # value — confirm the CLI restricts it to {'Lambda', 'Cosine'}.
    if args.lr_scheduler == 'Lambda':
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lambda step: (1.0 - step / (args.epochs * train_step))
            if step <= (args.epochs * train_step) else 0,
            last_epoch=-1)
    elif args.lr_scheduler == 'Cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs, eta_min=1e-8, last_epoch=-1)

    if use_gpu:
        model = nn.DataParallel(model)
        cudnn.benchmark = True
        loss_function = criterion.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion
        device = torch.device("cpu")
    model = model.to(device)

    # Stash everything on the args namespace for the train/test helpers.
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider
    args.best_acc = 0.0
    args.all_iters = 1

    start_epoch = 1
    for epoch in range(start_epoch, start_epoch + args.epochs):
        loss_output, train_acc = train_nni(args, model, device, epoch,
                                           train_step)
        acc, best_acc = test_nni(args, model, device, epoch, valid_step)
        print(
            'Epoch {}, loss/train acc = {:.2f}/{:.2f}, val acc/best acc = {:.2f}/{:.2f},'
            .format(epoch, loss_output, train_acc, acc, best_acc))
def prepare(args, RCV_CONFIG):
    """Apply an NNI-received hyper-parameter config onto args, then build
    data providers, model, optimizer and LR scheduler.

    Returns (model, device, train_step, valid_step).
    """
    # Hyper-parameters tuned by NNI; booleans arrive as the strings
    # 'True'/'False'.
    args.momentum = RCV_CONFIG['momentum']
    args.bn_process = True if RCV_CONFIG['bn_process'] == 'True' else False
    args.learning_rate = RCV_CONFIG['learning_rate']
    args.weight_decay = RCV_CONFIG['weight_decay']
    args.label_smooth = RCV_CONFIG['label_smooth']
    args.lr_scheduler = RCV_CONFIG['lr_scheduler']
    args.randAugment = True if RCV_CONFIG['randAugment'] == 'True' else False

    # if RCV_CONFIG['momentum'] == 'vgg':
    #     net = VGG('VGG19')
    # if RCV_CONFIG['model'] == 'resnet18':
    #     net = ResNet18()
    # if RCV_CONFIG['model'] == 'googlenet':
    #     net = GoogLeNet()

    use_gpu = False
    if torch.cuda.is_available():
        use_gpu = True

    if args.cifar100:
        # train_dataprovider, val_dataprovider, train_step, valid_step = dataset_cifar.get_dataset("cifar100", batch_size=args.batch_size, RandA=args.randAugment)

        # NOTE(review): despite the args.cifar100 flag (and the 100-class
        # model below) this loads "cifar10" — confirm this is intentional.
        train_dataprovider, val_dataprovider, train_step, valid_step = dataset_cifar.get_dataset(
            "cifar10", batch_size=args.batch_size, RandA=args.randAugment)
        print('load data successfully')
    else:
        # ImageNet-style folder datasets.
        # NOTE(review): train_step/valid_step are only assigned in the
        # branch above, so this path leaves them unbound for the return —
        # confirm it is ever taken.
        assert os.path.exists(args.train_dir)
        from dataset import DataIterator, SubsetSampler, OpencvResize, ToBGRTensor
        train_dataset = datasets.ImageFolder(
            args.train_dir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.ColorJitter(brightness=0.4,
                                       contrast=0.4,
                                       saturation=0.4),
                transforms.RandomHorizontalFlip(0.5),
                ToBGRTensor(),
            ]))
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=1,
                                                   pin_memory=use_gpu)
        train_dataprovider = DataIterator(train_loader)

        assert os.path.exists(args.val_dir)
        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            args.val_dir,
            transforms.Compose([
                OpencvResize(256),
                transforms.CenterCrop(224),
                ToBGRTensor(),
            ])),
                                                 batch_size=200,
                                                 shuffle=False,
                                                 num_workers=1,
                                                 pin_memory=use_gpu)
        val_dataprovider = DataIterator(val_loader)
        print('load data successfully')

    # Imagenet
    # from network import ShuffleNetV2_OneShot
    # model = ShuffleNetV2_OneShot(n_class=1000)

    # Special for cifar
    from network_origin import cifar_fast
    model = cifar_fast(input_size=32, n_class=100)

    # Optimizer
    optimizer = get_optim(args, model)

    # Label Smooth
    if args.label_smooth > 0:
        criterion = CrossEntropyLabelSmooth(100, args.label_smooth)
    else:
        # print('CrossEntropyLoss')
        criterion = nn.CrossEntropyLoss()

    # LR schedule: linear decay to zero over args.epochs * train_step steps,
    # or cosine annealing over args.epochs.
    # NOTE(review): scheduler stays unbound for any other args.lr_scheduler
    # value — confirm the NNI search space restricts it.
    if args.lr_scheduler == 'Lambda':
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lambda step: (1.0 - step / (args.epochs * train_step))
            if step <= (args.epochs * train_step) else 0,
            last_epoch=-1)
    elif args.lr_scheduler == 'Cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs, eta_min=1e-8, last_epoch=-1)

    if use_gpu:
        model = nn.DataParallel(model)
        cudnn.benchmark = True
        loss_function = criterion.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion
        device = torch.device("cpu")
    model = model.to(device)

    # Stash everything on the args namespace for the train/test helpers.
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_dataprovider = train_dataprovider
    args.val_dataprovider = val_dataprovider
    args.best_acc = 0.0
    args.all_iters = 1

    return model, device, train_step, valid_step