def train_neural_network(x_inpuT, y_inpuT, labels_path, val_labels_path,
                         save_loss_path, save_model_path, batch_size,
                         val_batch_size, image_height, image_width,
                         learning_rate, weight_decay, num_iter, epochs,
                         which_model, num_train_videos, num_val_videos,
                         win_size):
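    """Train a 3-D CNN video classifier with TensorFlow 1.x.

    x_inpuT / y_inpuT are placeholders for video batches and one-hot labels;
    which_model selects one of the conv3d*/ResNeXt architectures below.
    Relies on the module-level globals `depth` and `num_classes`.
    """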

    with tf.name_scope("cross_entropy"):

        prediction = None
        shapes_list = None
        if which_model == 1:
            prediction, shapes_list = conv3d1.inference(x_inpuT)

        elif which_model == 2:
            resnext_model = ResNeXt(x_inpuT, tf.constant(True, dtype=tf.bool))
            prediction = resnext_model.Build_ResNext(x_inpuT)
        elif which_model == 3:
            prediction = conv3d2.inference(x_inpuT)
        elif which_model == 4:
            prediction = conv3d3.inference(x_inpuT)
        elif which_model == 5:
            prediction, shapes_list = conv3d4.inference(x_inpuT)

        cost = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=prediction,
                                                    labels=y_inpuT))

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):

        # optimizer = 0
        if weight_decay is not None:
            print("weight decay applied.")
            optimizer = create_optimizer(cost, learning_rate, weight_decay)
        else:
            optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
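        # NOTE: create_optimizer is defined elsewhere in this module; a
        # plausible implementation (an assumption, not shown in this excerpt)
        # wraps tf.contrib.opt.AdamWOptimizer(weight_decay, learning_rate)
        # and calls .minimize(cost) to get decoupled weight decay.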

    with tf.name_scope("accuracy"):
        correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y_inpuT, 1))
        accuracy = tf.reduce_mean(tf.cast(correct, 'float'))

    # create the Saver used by saver.save() after each epoch below
    saver = tf.train.Saver(save_relative_paths=True)
    print("Calculating total parameters in the model")
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        print(shape)
        # print(len(shape))
        variable_parameters = 1
        for dim in shape:
            # print(dim)
            variable_parameters *= dim.value
        # print(variable_parameters)
        total_parameters += variable_parameters
    print(total_parameters)
    if shapes_list is not None:
        print(shapes_list)

    gpu_options = tf.GPUOptions(allow_growth=True)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:

        print("session starts!")

        sess.run(tf.global_variables_initializer())

        start_time = time.time()
        epoch_loss_list = []
        val_loss_list = []

        ori_height = 240
        ori_width = 320

        with open(labels_path, 'r') as f:
            lines = f.readlines()

        with open(val_labels_path, 'r') as f:
            val_lines = f.readlines()

        for epoch in range(epochs):

            print("Epoch {} started!".format(epoch + 1))
            epoch_start_time = time.time()

            epoch_loss = 0
            train_acc = 0

            window_size = win_size

            random.seed(7)
            print("Random seed fixed for training.")

            num_videos = num_train_videos
            num_batches = int(num_videos / batch_size)
            num_iter_total_data_per_epoch = num_iter

            for iter_index in range(num_iter_total_data_per_epoch):

                prev_line = -1
                num_line = 0
                num_video_in_curr_line = 0

                # for batch_num in range(num_batches):

                batch_x = np.zeros(
                    (batch_size, depth, ori_height, ori_width, 3))
                batch_y = np.zeros((batch_size, num_classes))
                batch_index = 0
                num_batch_completed = 0

                while True:
                    if prev_line != num_line:
                        curr_line = lines[num_line]
                        par_name, vid_name, label_info = extract_line_info(
                            curr_line)
                        total_video_in_curr_line = len(label_info)

                    final_video, label = get_final_video(
                        par_name,
                        vid_name,
                        label_info,
                        num_video_in_curr_line,
                        window_size,
                        set='train')
                    # print(final_video.shape)

                    batch_x[batch_index, :, :, :, :] = final_video

                    one_hot_label = np.zeros(num_classes)
                    one_hot_label[int(label) - 1] = 1  # labels are 1-indexed
                    batch_y[batch_index, :] = one_hot_label
                    batch_index += 1

                    if batch_index == batch_size:
                        # train_batch
                        batch_start_time = time.time()

                        mini_batch_x = data_augmentation(
                            batch_x, (image_height, image_width))
                        # mini_batch_x = mini_batch_x / 255.0
                        mini_batch_y = batch_y

                        perm = np.random.permutation(batch_size)
                        mini_batch_x = mini_batch_x[perm]
                        mini_batch_y = mini_batch_y[perm]

                        _optimizer, _cost, _prediction, _accuracy = sess.run(
                            [optimizer, cost, prediction, accuracy],
                            feed_dict={
                                x_inpuT: mini_batch_x,
                                y_inpuT: mini_batch_y
                            })
                        epoch_loss += _cost
                        train_acc += _accuracy

                        num_batch_completed += 1
                        batch_end_time = time.time()

                        total_train_batch_completed = iter_index * num_batches + num_batch_completed

                        log1 = "\rEpoch: {}, " \
                               "iter: {}, " \
                               "batches completed: {}, " \
                               "time taken: {:.5f}, " \
                               "loss: {:.6f}, " \
                               "accuracy: {:.4f} \n". \
                            format(
                            epoch + 1,
                            iter_index + 1,
                            total_train_batch_completed,
                            batch_end_time - batch_start_time,
                            epoch_loss / (batch_size * total_train_batch_completed),
                            _accuracy)

                        print(log1)
                        sys.stdout.flush()

                        if num_batch_completed == num_batches:
                            break

                        batch_index = 0
                        batch_x = np.zeros(
                            (batch_size, depth, ori_height, ori_width, 3))
                        batch_y = np.zeros((batch_size, num_classes))

                    prev_line = num_line
                    num_video_in_curr_line += 1
                    if num_video_in_curr_line == total_video_in_curr_line:
                        num_video_in_curr_line = 0
                        num_line += 1

            # validation loss
            print("<---------------- Validation Set started ---------------->")
            val_loss = 0
            val_acc = 0

            val_num_videos = num_val_videos
            val_num_batches = int(val_num_videos / val_batch_size)

            random.seed(23)
            print("Random seed fixed for validation.")

            for __ in range(num_iter_total_data_per_epoch):
                prev_line = -1
                num_line = 0
                num_video_in_curr_line = 0

                # for batch_num in range(val_num_batches):
                #
                val_batch_x = np.zeros(
                    (val_batch_size, depth, ori_height, ori_width, 3))
                val_batch_y = np.zeros((val_batch_size, num_classes))
                batch_index = 0
                val_num_batch_completed = 0

                while True:
                    if prev_line != num_line:
                        curr_line = val_lines[num_line]
                        par_name, vid_name, label_info = extract_line_info(
                            curr_line)
                        total_video_in_curr_line = len(label_info)

                    # process video
                    final_video, label = get_final_video(
                        par_name,
                        vid_name,
                        label_info,
                        num_video_in_curr_line,
                        window_size,
                        set='valid')
                    # print(final_video.shape)

                    val_batch_x[batch_index, :, :, :, :] = final_video

                    one_hot_label = np.zeros(num_classes)
                    one_hot_label[int(label) - 1] = 1  # labels are 1-indexed
                    val_batch_y[batch_index, :] = one_hot_label
                    batch_index += 1

                    if batch_index == val_batch_size:
                        val_batch_x = data_augmentation(
                            val_batch_x, (image_height, image_width))
                        # val_batch_x = val_batch_x / 255.0

                        perm = np.random.permutation(val_batch_size)
                        val_batch_x = val_batch_x[perm]
                        val_batch_y = val_batch_y[perm]

                        val_cost, val_batch_accuracy = sess.run(
                            [cost, accuracy],
                            feed_dict={
                                x_inpuT: val_batch_x,
                                y_inpuT: val_batch_y
                            })

                        val_acc += val_batch_accuracy
                        val_loss += val_cost

                        val_num_batch_completed += 1

                        if val_num_batch_completed == val_num_batches:
                            break

                        val_batch_x = np.zeros(
                            (val_batch_size, depth, ori_height, ori_width, 3))
                        val_batch_y = np.zeros((val_batch_size, num_classes))
                        batch_index = 0

                    prev_line = num_line
                    num_video_in_curr_line += 1
                    if num_video_in_curr_line == total_video_in_curr_line:
                        num_video_in_curr_line = 0
                        num_line += 1

            epoch_end_time = time.time()

            total_train_batch_completed = num_iter_total_data_per_epoch * num_batch_completed
            total_val_num_batch_completed = num_iter_total_data_per_epoch * val_num_batch_completed

            epoch_loss = epoch_loss / (batch_size *
                                       total_train_batch_completed)
            train_acc = train_acc / total_train_batch_completed

            val_loss /= (val_batch_size * total_val_num_batch_completed)
            val_acc = val_acc / total_val_num_batch_completed

            log3 = "Epoch {} done; " \
                   "Time Taken: {:.4f}s; " \
                   "Train_loss: {:.6f}; " \
                   "Val_loss: {:.6f}; " \
                   "Train_acc: {:.4f}; " \
                   "Val_acc: {:.4f}; " \
                   "Train batches: {}; " \
                   "Val batches: {};\n". \
                format(epoch + 1, epoch_end_time - epoch_start_time, epoch_loss, val_loss, train_acc, val_acc,
                       num_iter_total_data_per_epoch * num_batch_completed,
                       num_iter_total_data_per_epoch * val_num_batch_completed)

            print(log3)

            if save_loss_path is not None:
                with open(save_loss_path, "a") as loss_file:
                    loss_file.write(log3)

            epoch_loss_list.append(epoch_loss)
            val_loss_list.append(val_loss)

            if save_model_path is not None:
                saver.save(sess, save_model_path)

        end_time = time.time()
        print('Time elapsed:', end_time - start_time)
        print(epoch_loss_list)

        if save_loss_path is not None:
            with open(save_loss_path, "a") as loss_file:
                loss_file.write("Train Loss List: {}\n".format(epoch_loss_list))
                loss_file.write("Val Loss List: {}\n".format(val_loss_list))
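
A minimal invocation sketch (the paths and hyperparameter values below are illustrative assumptions, and `depth` and `num_classes` must be set as module-level globals, as the loop above requires):

import tensorflow as tf

bs = 8
x_inpuT = tf.placeholder(tf.float32, [bs, depth, 112, 112, 3], name='x')
y_inpuT = tf.placeholder(tf.float32, [bs, num_classes], name='y')

train_neural_network(x_inpuT, y_inpuT,
                     labels_path='train_labels.txt',
                     val_labels_path='val_labels.txt',
                     save_loss_path=None, save_model_path=None,
                     batch_size=bs, val_batch_size=bs,
                     image_height=112, image_width=112,
                     learning_rate=1e-4, weight_decay=None,
                     num_iter=1, epochs=10, which_model=1,
                     num_train_videos=800, num_val_videos=200,
                     win_size=16)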
Example 2
# The opening branch is cut off in this excerpt; reconstructed from the
# standard Keras channels_first/channels_last pattern (assumes
# `from keras import backend as K`).
if K.image_data_format() == 'channels_first':
    input_shape = (1, img_width, img_height)
else:
    input_shape = (img_width, img_height, 1)

# the data shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
nb_train_samples = x_train.shape[0]
nb_validation_samples = x_test.shape[0]

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# load model
model = ResNeXt(input_shape=input_shape,
                n_class=num_classes,
                weight_decay=weight_decay,
                batch_size=batch_size,
                cardinality=cardinality).model()

optimizer = SGD(lr=init_lr, momentum=momentum, nesterov=True)
model.compile(loss='categorical_crossentropy',
              optimizer=optimizer,
              metrics=['accuracy'])
# model.summary()

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255.
x_test /= 255.
x_train = x_train.reshape(x_train.shape[0], img_height, img_width, channels)
x_test = x_test.reshape(x_test.shape[0], img_height, img_width, channels)
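
The training call is not shown in this excerpt; a hedged sketch, reusing `batch_size` from above and assuming an `epochs` variable defined earlier in the script:

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,  # assumed hyperparameter, not shown in the excerpt
          validation_data=(x_test, y_test),
          shuffle=True)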
Example 3
# Excerpt from a dict that maps architecture names to model-constructor
# lambdas; the opening assignment is cut off in this excerpt.
    'DenseNetBC100':
    lambda args: L.Classifier(DenseNetBC(args.nclasses, (16, 16, 16), 12)),
    'DenseNetBC100_DConv':
    lambda args: L.Classifier(DenseNetBC_DConv(args.nclasses,
                                               (16, 16, 16), 12)),
    'DenseNetBC100_PGP':
    lambda args: FuseTrainWrapper(
        DenseNetBC_PGP(args.nclasses, (16, 16, 16), 12)),
    'WideResNet28-10':
    lambda args: L.Classifier(WideResNet(28, args.nclasses, 10)),
    'WideResNet28-10_DConv':
    lambda args: L.Classifier(WideResNet_DConv(28, args.nclasses, 10)),
    'WideResNet28-10_PGP':
    lambda args: FuseTrainWrapper(WideResNet_PGP(28, args.nclasses, 10)),
    'ResNeXt29_8x64d':
    lambda args: L.Classifier(ResNeXt(29, args.nclasses)),
    'ResNeXt29_8x64d_DConv':
    lambda args: L.Classifier(ResNeXt_DConv(29, args.nclasses)),
    'ResNeXt29_8x64d_PGP':
    lambda args: FuseTrainWrapper(ResNeXt_PGP(29, args.nclasses)),
    'PyramidNetB164':
    lambda args: L.Classifier(PyramidNet(164, args.nclasses)),
    'PyramidNetB164_DConv':
    lambda args: L.Classifier(PyramidNet_DConv(164, args.nclasses)),
    'PyramidNetB164_PGP':
    lambda args: FuseTrainWrapper(PyramidNet_PGP(164, args.nclasses)),
    'Shake-Shake26_2x64d':
    lambda args: L.Classifier(ShakeShake(26, args.nclasses, k=64)),
}
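A usage sketch for the table above, assuming the dict is bound to a hypothetical name `archs` and `args` is an argparse namespace with an `arch` field:

# Hypothetical lookup: instantiate the classifier selected on the command line.
model = archs[args.arch](args)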

Example 4
model_name = 'keras_cifar10_trained_model.h5'

# The data, shuffled and split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# initiate RMSprop optimizer
opt = keras.optimizers.rmsprop(lr=0.0001, decay=1e-6)

net = ResNeXt(input_shape=x_train.shape[1:], num_class=10, cardinality=8,
              depth=29, base_width=64)
model = net.build()

# Let's train the model using RMSprop
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

if not data_augmentation:
    print('Not using data augmentation.')
    # The call below is truncated in the source excerpt; completed with the
    # standard arguments from the stock Keras CIFAR-10 example (batch_size
    # and epochs are assumed to be defined earlier in the script).
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              shuffle=True)
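# The augmentation branch is cut off in this excerpt; a hedged sketch,
# modeled on the stock Keras CIFAR-10 example (not this author's exact code):
else:
    print('Using real-time data augmentation.')
    from keras.preprocessing.image import ImageDataGenerator
    datagen = ImageDataGenerator(width_shift_range=0.1,
                                 height_shift_range=0.1,
                                 horizontal_flip=True)
    model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
                        epochs=epochs,
                        validation_data=(x_test, y_test))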
Example 5
# Excerpt from a dict mapping architecture names to model constructors,
# following the same pattern as Example 3; the opening assignment is cut off.
    'PreResNet164_PGP': lambda args: FuseTrainWrapper(
        PreResNet_PGP(164, args.nclasses)),
    'DenseNetBC100': lambda args: L.Classifier(
        DenseNetBC(args.nclasses, (16, 16, 16), 12)),
    'DenseNetBC100_DConv': lambda args: L.Classifier(
        DenseNetBC_DConv(args.nclasses, (16, 16, 16), 12)),
    'DenseNetBC100_PGP': lambda args: FuseTrainWrapper(
        DenseNetBC_PGP(args.nclasses, (16, 16, 16), 12)),
    'WideResNet28-10': lambda args: L.Classifier(
        WideResNet(28, args.nclasses, 10)),
    'WideResNet28-10_DConv': lambda args: L.Classifier(
        WideResNet_DConv(28, args.nclasses, 10)),
    'WideResNet28-10_PGP': lambda args: FuseTrainWrapper(
        WideResNet_PGP(28, args.nclasses, 10)),
    'ResNeXt29_8x64d': lambda args: L.Classifier(
        ResNeXt(29, args.nclasses)),
    'ResNeXt29_8x64d_DConv': lambda args: L.Classifier(
        ResNeXt_DConv(29, args.nclasses)),
    'ResNeXt29_8x64d_PGP': lambda args: FuseTrainWrapper(
        ResNeXt_PGP(29, args.nclasses)),
    'PyramidNetB164': lambda args: L.Classifier(
        PyramidNet(164, args.nclasses)),
    'PyramidNetB164_DConv': lambda args: L.Classifier(
        PyramidNet_DConv(164, args.nclasses)),
    'PyramidNetB164_PGP': lambda args: FuseTrainWrapper(
        PyramidNet_PGP(164, args.nclasses)),
    'Shake-Shake26_2x64d': lambda args: L.Classifier(
        ShakeShake(26, args.nclasses, k=64)),
}

Example 6
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=args.test_bs,
                                          shuffle=False,
                                          num_workers=args.prefetch,
                                          pin_memory=True)

# /////////////// Model Setup ///////////////

# Create model

if args.model == 'resnext':
    from models.resnext import ResNeXt
    net = ResNeXt({
        'input_shape': (1, 3, 32, 32),
        'n_classes': num_classes,
        'base_channels': 32,
        'depth': 29,
        'cardinality': 8
    })
elif 'shake' in args.model:
    from models.shake_shake import ResNeXt
    net = ResNeXt({
        'input_shape': (1, 3, 32, 32),
        'n_classes': num_classes,
        'base_channels': 96,
        'depth': 26,
        "shake_forward": True,
        "shake_backward": True,
        "shake_image": True
    })
    args.epochs = 500
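
A hedged sketch of the evaluation pass that typically follows this setup (`net` and `test_loader` come from the excerpt above; device placement is omitted for brevity):

import torch

net.eval()
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        output = net(data)
        correct += output.argmax(dim=1).eq(target).sum().item()
print('test accuracy: {:.4f}'.format(correct / len(test_loader.dataset)))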