Example #1
0
def test(FLAG):
    """Evaluate every checkpoint in ``FLAG.save_dir`` at each DP level.

    For each checkpoint path listed in the checkpoint state of
    ``FLAG.save_dir`` the weights are restored, and for every DP value the
    tensor ``vgg16.accu_dict[str(int(dp * 100))]`` is run on the two parts of
    the test set (rows ``[:5000]`` and ``[5000:]``).  The mean of the two
    results is recorded and one CSV per checkpoint is written to
    ``"task<k>_<FLAG.output>"``, where ``k`` counts checkpoints from 1.

    Args:
        FLAG: namespace-like object.  This function reads ``dataset``,
            ``init_from``, ``prof_type``, ``save_dir`` and ``output``.

    Raises:
        ValueError: if ``FLAG.dataset`` is neither ``'CIFAR-10'`` nor
            ``'CIFAR-100'``.
    """
    print("Reading dataset...")
    if FLAG.dataset == 'CIFAR-10':
        dataset = CIFAR10(train=False)
    elif FLAG.dataset == 'CIFAR-100':
        dataset = CIFAR100(train=False)
    else:
        raise ValueError("dataset should be either CIFAR-10 or CIFAR-100.")

    Xtest, Ytest = dataset.test_data, dataset.test_labels

    print("Build VGG16 models...")
    dp = [(i + 1) * 0.05 for i in range(1, 20)]
    vgg16 = VGG16(FLAG.init_from, infer=True, prof_type=FLAG.prof_type)
    vgg16.build(dp=dp)

    if FLAG.save_dir is None:
        # This case used to fall through without any output; say so instead.
        print("No save_dir given; nothing to evaluate.")
        return

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(FLAG.save_dir)

        if not (ckpt and ckpt.model_checkpoint_path):
            # Same message the sibling single-checkpoint test() prints.
            print("No model checkpoint in %s" % FLAG.save_dir)
            return

        def mean_accuracy(accu_op):
            """Run ``accu_op`` on both test-set parts and average the results."""
            # NOTE(review): a plain mean of the two parts equals the overall
            # accuracy only if both parts have the same size (i.e. 10000 test
            # rows) -- confirm against the dataset loader.
            first = sess.run(accu_op,
                             feed_dict={
                                 vgg16.x: Xtest[:5000, :],
                                 vgg16.y: Ytest[:5000, :]
                             })
            second = sess.run(accu_op,
                              feed_dict={
                                  vgg16.x: Xtest[5000:, :],
                                  vgg16.y: Ytest[5000:, :]
                              })
            return (first + second) / 2

        for count, checkpoint in enumerate(ckpt.all_model_checkpoint_paths,
                                           start=1):
            saver.restore(sess, checkpoint)
            print("Model restored %s" % checkpoint)
            # The former sess.run(tf.global_variables()) only copied every
            # weight to the host and discarded it, so it is omitted.
            print("Initialized")

            output = []
            for dp_i in dp:
                perf = mean_accuracy(vgg16.accu_dict[str(int(dp_i * 100))])
                output.append(perf)
                print("At DP={dp:.4f}, accu={perf:.4f}".format(dp=dp_i,
                                                               perf=perf))

            res = pd.DataFrame.from_dict({
                'DP': [int(dp_i * 100) for dp_i in dp],
                'accu': output
            })
            out_name = "task%s_%s" % (count, FLAG.output)
            res.to_csv(out_name, index=False)
            print("Write into %s" % out_name)
Example #2
0
def test(FLAG):
    """Evaluate a VGG16 model at full DP and log the result.

    The model is built from ``FLAG.init_from``; when ``FLAG.fidelity`` is set
    the loaded parameter dict is first passed through ``dpSparsifyVGG16``.
    If ``FLAG.save_dir`` is given and contains a checkpoint, it is restored
    on top of the initialised variables.  Accuracy is the mean of the results
    on test rows ``[:5000]`` and ``[5000:]``.  The per-DP table is written to
    ``FLAG.output`` and one row holding every ``FLAG`` attribute is appended
    to ``performance.csv``.

    Args:
        FLAG: namespace-like object.  Reads ``dataset``, ``fidelity``,
            ``init_from``, ``keep_prob``, ``save_dir`` and ``output``; sets
            ``flops_M``, ``params_M`` and ``accuracy``.

    Raises:
        ValueError: if ``FLAG.dataset`` is neither ``'CIFAR-10'`` nor
            ``'CIFAR-100'``.
    """
    print("Reading dataset...")
    if FLAG.dataset == 'CIFAR-10':
        test_data = CIFAR10(train=False)
        vgg16 = VGG16(classes=10)
    elif FLAG.dataset == 'CIFAR-100':
        test_data = CIFAR100(train=False)
        vgg16 = VGG16(classes=100)
    else:
        raise ValueError("dataset should be either CIFAR-10 or CIFAR-100.")

    Xtest, Ytest = test_data.test_data, test_data.test_labels

    if FLAG.fidelity is not None:
        # The file holds a pickled dict, which NumPy >= 1.16.3 refuses to load
        # unless allow_pickle=True.  Unpickling can execute arbitrary code:
        # only point FLAG.init_from at trusted files.
        data_dict = np.load(FLAG.init_from,
                            encoding='latin1',
                            allow_pickle=True).item()
        data_dict = dpSparsifyVGG16(data_dict, FLAG.fidelity)
        vgg16.build(vgg16_npy_path=data_dict,
                    conv_pre_training=True,
                    fc_pre_training=True)
        print("Build model from %s using dp=%s" %
              (FLAG.init_from, str(FLAG.fidelity * 100)))
    else:
        vgg16.build(vgg16_npy_path=FLAG.init_from,
                    conv_pre_training=True,
                    fc_pre_training=True)
        print("Build full model from %s" % (FLAG.init_from))

    # Only the full-width configuration is evaluated here.
    dp = [1.0]
    vgg16.set_idp_operation(dp=dp, keep_prob=FLAG.keep_prob)

    flops, params = countFlopsParas(vgg16)
    print("Flops: %3f M, Paras: %3f M" % (flops / 1e6, params / 1e6))
    FLAG.flops_M = flops / 1e6
    FLAG.params_M = params / 1e6

    # The reported DP is scaled by the fidelity; without one the model is
    # unsparsified, so the scale is 1 (this used to raise TypeError on None).
    fidelity_scale = 1.0 if FLAG.fidelity is None else FLAG.fidelity

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if FLAG.save_dir is not None:
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(FLAG.save_dir)

            if ckpt and ckpt.model_checkpoint_path:
                # Previously referenced an undefined name `checkpoint`.
                saver.restore(sess, ckpt.model_checkpoint_path)
                print("Model restored %s" % ckpt.model_checkpoint_path)
            else:
                print("No model checkpoint in %s" % FLAG.save_dir)
        print("Initialized")

        output = []
        for dp_i in dp:
            accu_op = vgg16.accu_dict[str(int(dp_i * 100))]
            accu = sess.run(accu_op,
                            feed_dict={
                                vgg16.x: Xtest[:5000, :],
                                vgg16.y: Ytest[:5000, :],
                                vgg16.is_train: False
                            })
            accu2 = sess.run(accu_op,
                             feed_dict={
                                 vgg16.x: Xtest[5000:, :],
                                 vgg16.y: Ytest[5000:, :],
                                 vgg16.is_train: False
                             })
            perf = (accu + accu2) / 2
            output.append(perf)
            print("At DP={dp:.4f}, accu={perf:.4f}".format(
                dp=dp_i * fidelity_scale, perf=perf))

        res = pd.DataFrame.from_dict({
            'DP': [int(dp_i * 100) for dp_i in dp],
            'accu': output
        })
        res.to_csv(FLAG.output, index=False)
        print("Write into %s" % FLAG.output)

    # Accuracy of the last (here: only) DP level.
    FLAG.accuracy = output[-1]

    # One CSV row with every FLAG attribute, columns in sorted key order.
    keys = sorted(vars(FLAG))
    header = ",".join(keys) + "\n"
    row = ",".join(str(getattr(FLAG, key)) for key in keys) + "\n"

    perf_path = "/home/cmchang/new_CP_CNN/performance.csv"
    # The header is written only when the file is being created.
    needs_header = not os.path.exists(perf_path)
    with open(perf_path, "a") as myfile:
        if needs_header:
            myfile.write(header)
        myfile.write(row)
Example #3
0
def mkdir(path):
    """Create directory ``path``, including missing parents, if it is absent.

    ``exist_ok=True`` replaces the former check-then-create sequence, so the
    call no longer fails when the directory appears between the check and the
    creation, and nested output paths work without pre-creating parents.

    Args:
        path: directory to create.

    Raises:
        OSError: if the directory cannot be created, e.g. ``path`` already
            exists as a regular file.
    """
    os.makedirs(path, exist_ok=True)


if __name__ == '__main__':
    args = parse_args()
    cfg = config_from_file(args.config)
    data_dir = cfg.DATA_DIR
    arch = args.arch
    attention = args.attention
    output_dir = args.output_dir
    mkdir(output_dir)
    checkpoint = cfg.CHECKPOINT
    save_dir = os.path.join(checkpoint, '_'.join([arch, attention]))
    test_dataset = CIFAR100(data_dir, is_test=True, augmentation=False)
    if arch != 'None':
        model = create_model(arch, attention)
        load_point = os.path.join(save_dir,
                                  'model_{}.pth'.format(args.load_epoch))
        saved_dict = torch.load(load_point)
        state_dict = dict_nnDataParallel(saved_dict['state_dict'])
        model.load_state_dict(state_dict)
        model.eval()
        for i in range(8):
            img, label = test_dataset[i]
            img = torch.unsqueeze(img, 0)
            heatmap, pred = get_gradCAM_heatmap(img, model, label)
            sp_img = view_gradCAM(img, heatmap)
            output_file = os.path.join(
                output_dir,
Example #4
0
def train(FLAG):
    """Train VGG16 on the ``'var_dp'`` objective and export the best weights.

    The model is built for the dataset named by ``FLAG.dataset`` and
    ``vgg16.loss_dict['var_dp']`` is minimised on augmented mini-batches.
    After each epoch the test set is evaluated in batches of 200; whenever
    the validation accuracy improves by more than ``min_delta`` the current
    values of ``vgg16.para_dict`` are saved to ``para_dict.npy``.  Training
    ends after 500 epochs or 50 consecutive epochs without improvement.
    Finally a checkpoint, ``sparse_dict.npy`` and ``sprocess.npy`` are
    written to ``FLAG.save_dir`` and one row with every ``FLAG`` attribute is
    appended to ``model.csv``.

    Args:
        FLAG: namespace-like object.  Reads ``dataset``, ``init_from``,
            ``prof_type``, ``lambda_s``, ``lambda_m``, ``decay``,
            ``keep_prob`` and ``save_dir``; sets ``optimizer``, ``lr``,
            ``batch_size``, ``epoch_end`` and ``val_accu``.

    Raises:
        ValueError: if ``FLAG.dataset`` is neither ``'CIFAR-10'`` nor
            ``'CIFAR-100'``.
    """
    print("Reading dataset...")
    if FLAG.dataset == 'CIFAR-10':
        train_data = CIFAR10(train=True)
        test_data = CIFAR10(train=False)
        vgg16 = VGG16(classes=10)
    elif FLAG.dataset == 'CIFAR-100':
        train_data = CIFAR100(train=True)
        test_data = CIFAR100(train=False)
        vgg16 = VGG16(classes=100)
    else:
        raise ValueError("dataset should be either CIFAR-10 or CIFAR-100.")
    print("Build VGG16 models for %s..." % FLAG.dataset)

    Xtrain, Ytrain = train_data.train_data, train_data.train_labels
    Xtest, Ytest = test_data.test_data, test_data.test_labels

    vgg16.build(vgg16_npy_path=FLAG.init_from,
                prof_type=FLAG.prof_type,
                conv_pre_training=True,
                fc_pre_training=False)
    vgg16.sparsity_train(l1_gamma=FLAG.lambda_s,
                         l1_gamma_diff=FLAG.lambda_m,
                         decay=FLAG.decay,
                         keep_prob=FLAG.keep_prob)

    # define tasks
    tasks = ['var_dp']
    print(tasks)

    # initial task
    cur_task = tasks[0]
    obj = vgg16.loss_dict[cur_task]

    # The Saver is created before the optimizer, so it covers only the
    # variables that exist at this point.
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=len(tasks))

    checkpoint_path = os.path.join(FLAG.save_dir, 'model.ckpt')
    tvars_trainable = tf.trainable_variables()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # hyper parameters
        batch_size = 64
        val_batch_size = 200
        epoch = 500
        early_stop_patience = 50
        min_delta = 0.0001
        opt_type = 'adam'

        n_train_batches = int(Xtrain.shape[0] / batch_size)
        n_val_batches = int(Xtest.shape[0] / val_batch_size)

        # recorder
        epoch_counter = 0

        # Passing global_step to minimize() will increment it at each step.
        global_step = tf.Variable(0, trainable=False)

        # `==`, not `is`: identity of string literals is an implementation
        # detail and triggers a SyntaxWarning on recent Python versions.
        if opt_type == 'sgd':
            start_learning_rate = 1e-4
            half_cycle = 20000
            learning_rate = tf.train.exponential_decay(start_learning_rate,
                                                       global_step,
                                                       half_cycle,
                                                       0.5,
                                                       staircase=True)
            opt = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                             momentum=0.9,
                                             use_nesterov=True)
        else:
            start_learning_rate = 1e-4
            half_cycle = 10000
            learning_rate = tf.train.exponential_decay(start_learning_rate,
                                                       global_step,
                                                       half_cycle,
                                                       0.5,
                                                       staircase=True)
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate)

        train_op = opt.minimize(obj,
                                global_step=global_step,
                                var_list=tvars_trainable)

        # progress bar widgets
        ptrain = IntProgress()
        pval = IntProgress()
        display(ptrain)
        display(pval)
        ptrain.max = n_train_batches
        # Validation runs in batches of val_batch_size, so size its bar with
        # that (it was sized with the training batch size before).
        pval.max = n_val_batches

        spareness = vgg16.spareness(thresh=0.05)
        print("initial spareness: %s" % sess.run(spareness))

        # Initialise the variables created after the global initializer ran
        # (global_step and whatever opt.minimize added).
        initialize_uninitialized(sess)

        patience_counter = 0
        current_best_val_accu = 0
        # Snapshot of the best parameters; stays None until an epoch improves.
        para_dict = None

        while (patience_counter < early_stop_patience
               and epoch_counter < epoch):

            def load_batches():
                for i in range(int(Xtrain.shape[0] / batch_size)):
                    st = i * batch_size
                    ed = (i + 1) * batch_size
                    batch = ia.Batch(images=Xtrain[st:ed, :, :, :],
                                     data=Ytrain[st:ed, :])
                    yield batch

            batch_loader = ia.BatchLoader(load_batches)
            bg_augmenter = ia.BackgroundAugmenter(batch_loader=batch_loader,
                                                  augseq=transform,
                                                  nb_workers=4)

            # start training
            stime = time.time()
            bar_train = Bar(
                'Training',
                max=n_train_batches,
                suffix='%(index)d/%(max)d - %(percent).1f%% - %(eta)ds')
            bar_val = Bar(
                'Validation',
                max=n_val_batches,
                suffix='%(index)d/%(max)d - %(percent).1f%% - %(eta)ds')

            train_loss, train_accu = 0.0, 0.0
            train_batches = 0
            while True:
                batch = bg_augmenter.get_batch()
                if batch is None:
                    print("Finished epoch.")
                    break
                loss, accu, _ = sess.run(
                    [obj, vgg16.accu_dict[cur_task], train_op],
                    feed_dict={
                        vgg16.x: batch.images_aug,
                        vgg16.y: batch.data,
                        vgg16.is_train: True
                    })
                bar_train.next()
                train_loss += loss
                train_accu += accu
                train_batches += 1
                ptrain.value += 1
                ptrain.description = "Training %s/%s" % (ptrain.value,
                                                         ptrain.max)
            # Average over the batches actually processed rather than over the
            # widget's value, which only mirrors the count for display.
            train_loss = train_loss / max(train_batches, 1)
            train_accu = train_accu / max(train_batches, 1)
            batch_loader.terminate()
            bg_augmenter.terminate()

            # validation
            val_loss = 0
            val_accu = 0
            for i in range(n_val_batches):
                st = i * val_batch_size
                ed = (i + 1) * val_batch_size
                loss, accu = sess.run(
                    [obj, vgg16.accu_dict[cur_task]],
                    feed_dict={
                        vgg16.x: Xtest[st:ed, :],
                        vgg16.y: Ytest[st:ed, :],
                        vgg16.is_train: False
                    })
                val_loss += loss
                val_accu += accu
                bar_val.next()
                pval.value += 1
                pval.description = "Testing %s/%s" % (pval.value, pval.max)
            val_loss = val_loss / max(n_val_batches, 1)
            val_accu = val_accu / max(n_val_batches, 1)

            print("\nspareness: %s" % sess.run(spareness))
            # early stopping check
            if (val_accu - current_best_val_accu) > min_delta:
                current_best_val_accu = val_accu
                patience_counter = 0

                para_dict = sess.run(vgg16.para_dict)
                np.save(os.path.join(FLAG.save_dir, "para_dict.npy"),
                        para_dict)
                print("save in %s" %
                      os.path.join(FLAG.save_dir, "para_dict.npy"))
            else:
                patience_counter += 1

            # shuffle Xtrain and Ytrain in the next epoch
            idx = np.random.permutation(Xtrain.shape[0])
            Xtrain, Ytrain = Xtrain[idx, :, :, :], Ytrain[idx, :]

            # epoch end
            epoch_counter += 1

            ptrain.value = 0
            pval.value = 0
            bar_train.finish()
            bar_val.finish()

            print(
                "Epoch %s (%s), %s sec >> train loss: %.4f, train accu: %.4f, val loss: %.4f, val accu at %s: %.4f"
                % (epoch_counter, patience_counter,
                   round(time.time() - stime, 2), train_loss, train_accu,
                   val_loss, cur_task, val_accu))
        saver.save(sess, checkpoint_path, global_step=epoch_counter)

        if para_dict is None:
            # No epoch ever beat the initial best, so nothing was snapshotted;
            # use the current weights instead of failing on an unbound name.
            para_dict = sess.run(vgg16.para_dict)

        sp, rcut = gammaSparsifyVGG16(para_dict, thresh=0.02)
        np.save(os.path.join(FLAG.save_dir, "sparse_dict.npy"), sp)
        print("sparsify %s in %s" % (np.round(
            1 - rcut, 3), os.path.join(FLAG.save_dir, "sparse_dict.npy")))

        # NOTE(review): arr_spareness is not defined inside this function;
        # presumably a module-level list -- confirm it exists, otherwise this
        # line raises NameError.
        arr_spareness.append(1 - rcut)
        np.save(os.path.join(FLAG.save_dir, "sprocess.npy"), arr_spareness)

    FLAG.optimizer = opt_type
    FLAG.lr = start_learning_rate
    FLAG.batch_size = batch_size
    FLAG.epoch_end = epoch_counter
    FLAG.val_accu = current_best_val_accu

    # One CSV row with every FLAG attribute, columns in sorted key order.
    keys = sorted(vars(FLAG))
    header = ",".join(keys) + "\n"
    row = ",".join(str(getattr(FLAG, key)) for key in keys) + "\n"

    model_csv = "/home/cmchang/new_CP_CNN/model.csv"
    # The header is written only when the file is being created.
    needs_header = not os.path.exists(model_csv)
    with open(model_csv, "a") as myfile:
        if needs_header:
            myfile.write(header)
        myfile.write(row)
Example #5
0
def train(FLAG):
    """Train VGG16 task by task, summing the losses of the tasks seen so far.

    Tasks are the keys ``'100'`` and ``'50'`` of ``vgg16.loss_dict``.  For
    each task the objective is the sum of that task's loss and the objective
    of the previous tasks.  It is minimised with Adam over all trainable
    variables except those in ``vgg16.gamma_var``, until the validation loss
    has not improved by more than ``min_delta`` for 4 consecutive epochs.  A
    checkpoint is saved when moving on to the next task and once more at the
    end.  The summary of the last validation batch of each epoch is written
    to ``FLAG.log_dir``.

    Args:
        FLAG: namespace-like object.  Reads ``dataset``, ``init_from``,
            ``prof_type``, ``save_dir`` and ``log_dir``.

    Raises:
        ValueError: if ``FLAG.dataset`` is neither ``'CIFAR-10'`` nor
            ``'CIFAR-100'``.
    """
    print("Reading dataset...")
    if FLAG.dataset == 'CIFAR-10':
        train_data = CIFAR10(train=True)
        test_data = CIFAR10(train=False)
    elif FLAG.dataset == 'CIFAR-100':
        train_data = CIFAR100(train=True)
        test_data = CIFAR100(train=False)
    else:
        raise ValueError("dataset should be either CIFAR-10 or CIFAR-100.")

    Xtrain, Ytrain = train_data.train_data, train_data.train_labels
    Xtest, Ytest = test_data.test_data, test_data.test_labels

    print("Build VGG16 models...")
    vgg16 = VGG16(FLAG.init_from, prof_type=FLAG.prof_type)

    # build model using dp
    dp = [(i + 1) * 0.05 for i in range(1, 20)]
    vgg16.build(dp=dp)

    # define tasks
    tasks = ['100', '50']
    print(tasks)

    # Created before the optimizer, so it covers only the variables that
    # exist at this point.
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=len(tasks))

    checkpoint_path = os.path.join(FLAG.save_dir, 'model.ckpt')

    # Exclude the gamma variables from optimisation.
    tvars_trainable = tf.trainable_variables()
    for rm in vgg16.gamma_var:
        tvars_trainable.remove(rm)
        print('%s is not trainable.' % rm)

    def initialize_uninitialized(sess):
        """Initialise only the global variables that are not yet initialised."""
        global_vars = tf.global_variables()
        initialized_flags = sess.run(
            [tf.is_variable_initialized(var) for var in global_vars])
        not_initialized_vars = [
            v for (v, f) in zip(global_vars, initialized_flags) if not f
        ]
        if not_initialized_vars:
            sess.run(tf.variables_initializer(not_initialized_vars))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # hyper parameters
        learning_rate = 2e-4
        batch_size = 32
        val_batch_size = 200
        early_stop_patience = 4
        min_delta = 0.0001

        n_train_batches = int(Xtrain.shape[0] / batch_size)
        n_val_batches = int(Xtest.shape[0] / val_batch_size)

        # optimizer
        opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
        # recorder
        epoch_counter = 0

        # tensorboard writer
        writer = tf.summary.FileWriter(FLAG.log_dir, sess.graph)
        try:
            # progress bar widgets
            ptrain = IntProgress()
            pval = IntProgress()
            display(ptrain)
            display(pval)
            ptrain.max = n_train_batches
            # Validation runs in batches of val_batch_size, so size its bar
            # with that (it was sized with the training batch size before).
            pval.max = n_val_batches

            # initial task
            obj = vgg16.loss_dict[tasks[0]]

            while tasks:

                # acquire a new task
                cur_task = tasks[0]
                tasks = tasks[1:]
                new_obj = vgg16.loss_dict[cur_task]

                # just finished a task
                if epoch_counter > 0:
                    # save models
                    saver.save(sess,
                               checkpoint_path,
                               global_step=epoch_counter)

                    # task-wise loss aggregation
                    obj = tf.add(obj, new_obj)

                # optimizer
                train_op = opt.minimize(obj, var_list=tvars_trainable)

                # Each minimize() call may add new variables; initialise just
                # those so already-trained weights are left alone.
                initialize_uninitialized(sess)

                # reset due to adding a new task
                patience_counter = 0
                current_best_val_loss = 100000  # a large number

                # optimize the aggregated obj
                while patience_counter < early_stop_patience:
                    stime = time.time()
                    bar_train = Bar(
                        'Training',
                        max=n_train_batches,
                        suffix='%(index)d/%(max)d - %(percent).1f%% - %(eta)ds'
                    )
                    bar_val = Bar(
                        'Validation',
                        max=n_val_batches,
                        suffix='%(index)d/%(max)d - %(percent).1f%% - %(eta)ds'
                    )

                    # training an epoch
                    for i in range(n_train_batches):
                        st = i * batch_size
                        ed = (i + 1) * batch_size
                        sess.run([train_op],
                                 feed_dict={
                                     vgg16.x: Xtrain[st:ed, :, :, :],
                                     vgg16.y: Ytrain[st:ed, :]
                                 })
                        ptrain.value += 1
                        ptrain.description = "Training %s/%s" % (i,
                                                                 ptrain.max)
                        bar_train.next()

                    # validation
                    val_loss = 0
                    val_accu = 0
                    for i in range(n_val_batches):
                        st = i * val_batch_size
                        ed = (i + 1) * val_batch_size
                        loss, accu, epoch_summary = sess.run(
                            [obj, vgg16.accu_dict[cur_task], vgg16.summary_op],
                            feed_dict={
                                vgg16.x: Xtest[st:ed, :],
                                vgg16.y: Ytest[st:ed, :]
                            })
                        val_loss += loss
                        val_accu += accu
                        bar_val.next()
                        pval.value += 1
                        pval.description = "Testing %s/%s" % (i, pval.max)
                    # Average over the batch count rather than over the
                    # widget's value, which only mirrors it for display.
                    val_loss = val_loss / max(n_val_batches, 1)
                    val_accu = val_accu / max(n_val_batches, 1)

                    # early stopping check
                    if (current_best_val_loss - val_loss) > min_delta:
                        current_best_val_loss = val_loss
                        patience_counter = 0
                    else:
                        patience_counter += 1

                    # shuffle Xtrain and Ytrain in the next epoch
                    idx = np.random.permutation(Xtrain.shape[0])
                    Xtrain, Ytrain = Xtrain[idx, :, :, :], Ytrain[idx, :]

                    # epoch end: log the summary of the last validation batch
                    writer.add_summary(epoch_summary, epoch_counter)
                    epoch_counter += 1

                    ptrain.value = 0
                    pval.value = 0
                    bar_train.finish()
                    bar_val.finish()

                    print(
                        "Epoch %s (%s), %s sec >> obj loss: %.4f, task at %s: %.4f"
                        % (epoch_counter, patience_counter,
                           round(time.time() - stime,
                                 2), val_loss, cur_task, val_accu))
            saver.save(sess, checkpoint_path, global_step=epoch_counter)
        finally:
            # Close the event file even if training is interrupted.
            writer.close()
Example #6
0
def test(FLAG):
    """Evaluate a VGG16 model built from ``FLAG.init_from`` at each DP level.

    When ``FLAG.fidelity`` is set, the loaded parameter dict is first passed
    through ``dpSparsifyVGG16``.  For every DP value the tensor
    ``vgg16.accu_dict[str(int(dp * 100))]`` is run on test rows ``[:5000]``
    and ``[5000:]`` and the mean of the two results is recorded.  The
    resulting table is written to ``FLAG.output`` as CSV.

    Args:
        FLAG: namespace-like object.  Reads ``dataset``, ``fidelity``,
            ``init_from``, ``prof_type``, ``keep_prob`` and ``output``.

    Raises:
        ValueError: if ``FLAG.dataset`` is neither ``'CIFAR-10'`` nor
            ``'CIFAR-100'``.
    """
    print("Reading dataset...")
    if FLAG.dataset == 'CIFAR-10':
        test_data = CIFAR10(train=False)
        vgg16 = VGG16(classes=10)
    elif FLAG.dataset == 'CIFAR-100':
        test_data = CIFAR100(train=False)
        vgg16 = VGG16(classes=100)
    else:
        raise ValueError("dataset should be either CIFAR-10 or CIFAR-100.")

    Xtest, Ytest = test_data.test_data, test_data.test_labels

    if FLAG.fidelity is not None:
        # The file holds a pickled dict, which NumPy >= 1.16.3 refuses to load
        # unless allow_pickle=True.  Unpickling can execute arbitrary code:
        # only point FLAG.init_from at trusted files.
        data_dict = np.load(FLAG.init_from,
                            encoding='latin1',
                            allow_pickle=True).item()
        data_dict = dpSparsifyVGG16(data_dict, FLAG.fidelity)
        vgg16.build(vgg16_npy_path=data_dict,
                    prof_type=FLAG.prof_type,
                    conv_pre_training=True,
                    fc_pre_training=True)
        print("Build model from %s using dp=%s" %
              (FLAG.init_from, str(FLAG.fidelity * 100)))
    else:
        vgg16.build(vgg16_npy_path=FLAG.init_from,
                    prof_type=FLAG.prof_type,
                    conv_pre_training=True,
                    fc_pre_training=True)
        print("Build full model from %s" % (FLAG.init_from))

    # build model using dp
    dp = [(i + 1) * 0.05 for i in range(1, 20)]
    vgg16.set_idp_operation(dp=dp, keep_prob=FLAG.keep_prob)

    flops, params = countFlopsParas(vgg16)
    print("Flops: %3f M, Paras: %3f M" % (flops / 1e6, params / 1e6))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # The former sess.run(tf.global_variables()) only copied every weight
        # to the host and discarded it, so it is omitted.
        print("Initialized")

        output = []
        for dp_i in dp:
            accu_op = vgg16.accu_dict[str(int(dp_i * 100))]
            accu = sess.run(accu_op,
                            feed_dict={
                                vgg16.x: Xtest[:5000, :],
                                vgg16.y: Ytest[:5000, :],
                                vgg16.is_train: False
                            })
            accu2 = sess.run(accu_op,
                             feed_dict={
                                 vgg16.x: Xtest[5000:, :],
                                 vgg16.y: Ytest[5000:, :],
                                 vgg16.is_train: False
                             })
            perf = (accu + accu2) / 2
            output.append(perf)
            print("At DP={dp:.4f}, accu={perf:.4f}".format(dp=dp_i,
                                                           perf=perf))

        res = pd.DataFrame.from_dict({
            'DP': [int(dp_i * 100) for dp_i in dp],
            'accu': output
        })
        res.to_csv(FLAG.output, index=False)
        print("Write into %s" % FLAG.output)
Example #7
0
def trainval(model, save_dir, cfg, resume_epoch=None):
    """Train ``model`` on CIFAR-100 with periodic validation and checkpoints.

    Uses SGD (momentum 0.9) with a ``MultiStepLR`` schedule and cross-entropy
    loss.  The per-iteration training loss and the validation top-1/top-5
    values returned by ``valid`` are logged to a ``SummaryWriter`` under
    ``cfg.LOG_DIR``.  Every ``cfg.SAVE_STEP`` epochs a dict with the keys
    ``'epoch'``, ``'state_dict'`` and ``'iter_counter'`` is saved to
    ``save_dir/model_<epoch>.pth``.

    Args:
        model: network to train; it is moved to CUDA and wrapped in
            ``nn.DataParallel``.
        save_dir: directory for checkpoints; its last path component names
            the log sub-directory.
        cfg: configuration object; reads ``TRAIN.*``, ``LOG_DIR``,
            ``VALID_STEP`` and ``SAVE_STEP``.
        resume_epoch: if given, ``model_<resume_epoch>.pth`` in ``save_dir``
            is loaded and training continues from the following epoch.

    Returns:
        The ``nn.DataParallel``-wrapped model.
    """
    # NOTE(review): data_dir is not a parameter; it is read from the enclosing
    # (module) scope -- confirm it is defined before this function is called.
    train_dataset = CIFAR100(data_dir, augmentation=cfg.TRAIN.AUGMENTATION)
    test_dataset = CIFAR100(data_dir, is_test=True, augmentation=False)
    train_loader = DataLoader(train_dataset,
                              batch_size=cfg.TRAIN.BATCH_SIZE,
                              shuffle=True,
                              num_workers=40)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=cfg.TRAIN.LR,
                          momentum=0.9,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    lrScheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, cfg.TRAIN.DECAY_EPOCH, gamma=cfg.TRAIN.LR_DECAY_GAMMA)

    # basename(normpath(...)) still yields the run name when save_dir ends
    # with a separator, where split('/')[-1] returned an empty string.
    run_name = os.path.basename(os.path.normpath(save_dir))
    logdir = os.path.join(cfg.LOG_DIR, run_name)
    summaryWriter = SummaryWriter(logdir)
    try:
        summaryWriter.add_graph(model, torch.zeros(1, 3, 32, 32))

        model.cuda()
        model = nn.DataParallel(model)
        start_epoch = -1
        iter_counter = 0
        if resume_epoch is not None:
            load_point = os.path.join(save_dir,
                                      'model_{}.pth'.format(resume_epoch))
            # torch.load unpickles the file: only resume from trusted
            # checkpoints.
            saved_dict = torch.load(load_point)
            model.load_state_dict(saved_dict['state_dict'])
            start_epoch = saved_dict['epoch']
            iter_counter = saved_dict['iter_counter']

        # Advance the LR schedule past the epochs that were already trained.
        for _ in range(start_epoch + 1):
            lrScheduler.step()
        print('Start at epoch %d' % (start_epoch + 1))

        for epoch in range(start_epoch + 1, cfg.TRAIN.EPOCH):
            progress = tqdm(train_loader,
                            desc='Epoch {}/{}'.format(epoch + 1,
                                                      cfg.TRAIN.EPOCH))
            for images, labels in progress:
                images = images.cuda()
                labels = labels.cuda()
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                summaryWriter.add_scalar('train/loss', loss.item(),
                                         iter_counter)
                iter_counter += 1
                loss.backward()
                optimizer.step()
            lrScheduler.step()

            # Validate on the last epoch of every VALID_STEP-sized window.
            if epoch % cfg.VALID_STEP == (cfg.VALID_STEP - 1):
                model.eval()
                top1_acc, top5_acc = valid(model, test_dataset, cfg)
                model.train()
                summaryWriter.add_scalar('valid/top1', top1_acc, epoch + 1)
                summaryWriter.add_scalar('valid/top5', top5_acc, epoch + 1)

            # Checkpoint on the last epoch of every SAVE_STEP-sized window.
            if epoch % cfg.SAVE_STEP == (cfg.SAVE_STEP - 1):
                save_dict = {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'iter_counter': iter_counter
                }
                filename = os.path.join(save_dir,
                                        'model_{}.pth'.format(epoch))
                torch.save(save_dict, filename)
    finally:
        # Flush and release the event file even if training is interrupted.
        summaryWriter.close()
    return model