Esempio n. 1
0
    def _load(self, pth_path: str):
        """Load a trained checkpoint and prepare the network for inference.

        Args:
            pth_path (str): Path to the ``.pth`` checkpoint file.

        Raises:
            FileNotFoundError: If ``pth_path`` does not exist.
        """
        # check exist
        if not os.path.exists(pth_path):
            # Raise FileNotFoundError with (errno, strerror, filename) so the
            # exception's .errno and .filename attributes are populated; the
            # previous FileNotFoundError(OSError(...)) wrapping lost them.
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
                                    pth_path)

        # load checkpoint; map_location lets a checkpoint saved on GPU be
        # loaded on a CPU-only machine (and vice versa)
        checkpoint = torch.load(pth_path, map_location=self.device)

        # classes
        self.classes = checkpoint["classes"]
        self.classify_size = len(self.classes)

        # rebuild the network with the same architecture, then restore weights
        self.net = cnn.Net(
            input_size=self.input_size,
            classify_size=self.classify_size,
            in_channels=self.in_channels,
        )
        self.net.load_state_dict(checkpoint["model_state_dict"])

        self.net.to(self.device)  # switch to GPU / CPU
        self.net.eval()  # switch to eval
    def _build_model(self):
        """Construct the network, the optimizer, and the loss criterion."""
        # network
        self.net = cnn.Net(input_size=self.input_size)

        # Adam optimizer (a plain-SGD variant was tried previously)
        # self.optimizer = optim.SGD(self.net.parameters(), lr=0.01)
        self.optimizer = optim.Adam(
            self.net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
        self.criterion = nn.CrossEntropyLoss()

        self.net.zero_grad()  # init all gradient
        self.net.to(self.device)  # switch to GPU / CPU
Esempio n. 3
0
def test(test_set, PATH):
    """Evaluate a saved binary classifier on `test_set`.

    Args:
        test_set: Dataset yielding (image, label) pairs with labels in {0, 1}.
        PATH: Path of the saved state-dict file.

    Returns:
        Tuple (loss_t, accuracy): average batch loss scaled by 100 (matching
        the original reporting convention) and overall accuracy in percent.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    classes = ('0', '1')

    batch_size = 4
    num_workers = 2

    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=num_workers)

    net = cnn.Net().to(device)
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine
    net.load_state_dict(torch.load(PATH, map_location=device))

    correct = 0
    total = 0
    n_class_correct = [0 for i in range(2)]
    n_class_samples = [0 for i in range(2)]
    criterion = nn.CrossEntropyLoss()
    num_batches = 0
    with torch.no_grad():

        epoch_loss = 0.0

        for i, (images, labels) in enumerate(test_loader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)
            loss = criterion(outputs, labels)
            epoch_loss += loss.item()
            num_batches += 1
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # iterate over the actual batch length: the last batch may be
            # smaller than batch_size, and the previous range(batch_size)
            # raised IndexError on a partial final batch
            for l in range(labels.size(0)):
                label = labels[l]
                pred = predicted[l]
                if (label == pred):
                    n_class_correct[label] += 1
                n_class_samples[label] += 1

    for m in range(2):
        # skip classes absent from the test set (avoids ZeroDivisionError)
        if n_class_samples[m]:
            acc = 100 * n_class_correct[m] / n_class_samples[m]
            print('Accuracy of %s: %.3f %%' % (classes[m], acc))

    accuracy = 100 * correct / total
    # divide by the batch count; the previous `/ i` used the last enumerate
    # index (off by one, and a ZeroDivisionError with exactly one batch)
    loss_t = 100 * epoch_loss / num_batches

    print('Accuracy of the network on the 10000 test images: %.3f %%' %
          (accuracy))
    print('Loss of the network on the 10000 test images: %.3f %%' % (loss_t))

    return loss_t, accuracy
Esempio n. 4
0
def train_model():
    """Train the CNN on the training split, then persist weights and stats."""
    loader = load_dataset(train=True)

    # build the network on the configured device, in training mode
    net = cnn.Net()
    net.to(cnn.get_device())
    net.train()

    net, stats = cnn.train_model(model=net,
                                 dataloader=loader,
                                 num_epochs=10)
    print("Time spent: {:.2f}s".format(stats['exec_time']))
    torch.save(net.state_dict(), "./modelo/mnist.pt")
    save_stats(stats)
 def __init__(self,
              prototxt_name,
              model_name,
              batch_size=1,
              max_num_proposals=60,
              output_dim=201,
              iou_thres=0.3):
     """Set up the detection pipeline: CNN, region proposer, and NMS reducer.

     Args:
         prototxt_name: Name/path of the network prototxt definition.
         model_name: Name/path of the trained model weights.
         batch_size: Number of inputs per CNN forward pass.
         max_num_proposals: Upper bound on region proposals to generate.
         output_dim: Dimensionality of the detector output.
         iou_thres: IoU threshold used by the non-maximum-suppression reducer.
     """
     self.Net = cnn.Net(prototxt_name, model_name, batch_size)
     self.proposer = region_proposer.proposer(max_num_proposals)
     self.nms = nms.reducer(iou_thres)
     self.output_dim = output_dim
Esempio n. 6
0
def test_model():
    """Evaluate the saved CNN on the test split and report merged statistics."""
    loader = load_dataset(train=False)

    # restore the trained network in evaluation mode
    net = cnn.Net()
    net.load_state_dict(torch.load("./modelo/mnist.pt"))
    net.to(cnn.get_device())
    net.eval()

    stats = cnn.test_model(model=net, dataloader=loader)
    stats['classes'] = list(loader.dataset.class_to_idx.keys())

    # merge previously saved stats with the fresh test results
    stats = {**load_stats(), **stats}
    save_stats(stats)
    show_stats(stats)
def train(train_set):
    """Train a CNN on `train_set`, save its weights, and return the final
    average training loss.

    Args:
        train_set: Dataset yielding (image, label) pairs.

    Returns:
        float: mean per-batch training loss of the trained network.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    batch_size = 4
    num_workers = 2
    learning_rate = 0.001
    momentum = 0.9
    num_epochs = 2

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers)

    net = cnn.Net().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=learning_rate,
                          momentum=momentum)

    for epoch in range(num_epochs):

        for i, (inputs, labels) in enumerate(train_loader):

            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    # measure the final training loss in a separate pass; no_grad avoids
    # building autograd graphs (the original omitted it and wasted memory)
    epoch_loss = 0.0

    with torch.no_grad():
        for j, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = net(inputs)
            loss = criterion(outputs, labels)

            epoch_loss += loss.item()

    print('Finished Training')

    torch.save(net.state_dict(), './mnist_net.pth')
    print('Loss of the network on the train images: %.3f' % (epoch_loss /
                                                             (j + 1)))
    return epoch_loss / (j + 1)
Esempio n. 8
0
    def __init__(self):
        """Load the lane-detection CNN onto GPU/CPU and initialize state."""
        self.test_on_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.test_on_gpu else "cpu")
        self.model = cnn.Net()
        self.model.to(self.device)
        print(self.device)
        self.model.double()  # run the network in float64
        # map_location is required so a checkpoint saved on a GPU machine can
        # still be loaded on the CPU fallback path chosen above
        self.model.load_state_dict(
            torch.load(
                "./src/construction_site_lane_detection/models/nn_model.pt",
                map_location=self.device))
        self.model.eval()
        self.state = 1  # there is a red truck in front
        self.frame_count = 0  # number of frames processed so far
Esempio n. 9
0
def assert_image(dataloader, index):
    """Predict the class of a single dataset image, print and display it.

    Args:
        dataloader: DataLoader whose dataset supports indexing and exposes
            ``class_to_idx``.
        index: Index of the image within the dataset.

    Returns:
        The predicted class index as returned by ``cnn.assert_image``.
    """
    model = cnn.Net()
    model.load_state_dict(torch.load("./modelo/mnist.pt"))
    model.to(cnn.get_device())
    model.eval()

    image = dataloader.dataset[index][0]
    # Add the batch dimension and cast to float32 BEFORE moving to the model's
    # device: the original `.type('torch.FloatTensor')` call after `.to(...)`
    # returned a CPU tensor, silently undoing the device move on CUDA hosts.
    image = image[None].float().to(cnn.get_device())
    predictated = cnn.assert_image(model, image)

    # invert class_to_idx to map the predicted index back to its class name
    class_map = dict(map(reversed, dataloader.dataset.class_to_idx.items()))
    print(class_map[predictated])
    show_image(dataloader, index)
    return predictated
Esempio n. 10
0
import numpy as np
import cnn
import torch.optim as optim
import torch.nn as nn
import sklearn.model_selection as skms
import torch

# NOTE(review): `Utl` is used below but does not appear in the visible import
# block above — confirm it is imported elsewhere in the original script.
print("===============Loading data...================")
x_train, x_test, y_train, y_test = Utl.load_data22_from_all()
print("x_train shape: ", x_train.shape)
print("x_test  shape: ", x_test.shape)
print("y_train shape: ", y_train.shape)
print("y_test  shape: ", y_test.shape)
print("===============Training model...==============")

# network, loss, and Adam optimizer with L2 weight decay
net = cnn.Net()
net.cuda()  # requires a CUDA-capable GPU
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001, weight_decay=0.01)

for epoch in range(1):  # loop over the dataset multiple times

    running_loss = 0.0
    for i in range(2000):
        inputs, labels = Utl.mini_batch(x_train, y_train, batch_size=200)
        labels = Utl.index_code(labels)
        inputs, labels = torch.from_numpy(inputs), torch.from_numpy(labels)
        inputs = torch.unsqueeze(inputs, 1)
        inputs = inputs.float()
        labels = labels.long()
        inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
Esempio n. 11
0
def fit(F,
        D,
        sr,
        time_win_sec,
        log_dir=os.getcwd(),
        log_id='id0',
        O=None,
        overwrite=False,
        plot_figures=True,
        sort_weights_PC1=False,
        sort_data_PC1=False):
    """Fit a single-layer CNN predicting responses D from features F.

    Args:
        F: Features, shape [n_stims, n_tps, n_feats].
        D: Responses, shape [n_stims, n_tps] or [n_stims, n_tps, n_resp].
        sr: Sampling rate of the time axis.
        time_win_sec: Convolution kernel duration in seconds.
        log_dir: Parent directory for the results subdirectory.
        log_id: Identifier used in saved file names.
        O: Optional dict of overrides for the default parameter dict P.
        overwrite: If True, refit even when cached stats exist on disk.
        plot_figures: If True, plot loss curves, predictions, and weights.
        sort_weights_PC1: Sort plotted weights by their first PC.
        sort_data_PC1: Sort plotted test stimuli by the data's first PC.

    Returns:
        Dict S with train/val/test losses, weights 'W', predictions 'Y',
        the train/val/test partition, and the test-set correlation.
    """
    # Use None as the default and build the dict per call: the original
    # `O={}` default is a single shared object, so overrides could leak
    # between calls if param_handling ever mutates it.
    if O is None:
        O = {}

    # default optional parameters
    P = {}
    P['rank'] = None
    P['weight_scale'] = 1e-6
    P['learning_rate'] = 0.001
    P['early_stopping_steps'] = 10
    P['train_val_test_frac'] = [0.6, 0.2, 0.2]
    P['eval_interval'] = 5
    P['max_iter'] = 1000
    P['act'] = 'relu'
    P['seed'] = 0
    P['stim_labels'] = None
    P['print_iter'] = True

    # replace defaults with optional parameters
    P = misc.param_handling(
        P,
        O,
        maxlen=100,
        delimiter='/',
        omit_key_from_idstring=['stim_labels', 'print_iter'])

    # feature dimensionality
    n_stims = F.shape[0]
    n_tps = F.shape[1]
    n_feats = F.shape[2]

    # data dimensionality
    if len(D.shape) == 2:
        D = np.reshape(D, [n_stims, n_tps, 1])
    n_resp = D.shape[2]
    data_dims = [n_stims, n_tps, n_resp]
    assert (D.shape[0] == n_stims)
    assert (D.shape[1] == n_tps)

    # create train, validation, test splits
    if P['stim_labels'] is None:
        labels = np.zeros((n_stims, 1))
    else:
        labels = P['stim_labels']
    train_val_test = misc.partition_within_labels(labels,
                                                  P['train_val_test_frac'],
                                                  seed=P['seed'])

    # directory to save results
    idstring = 'win-' + misc.num2str(time_win_sec) + '_sr-' + misc.num2str(sr)
    # NOTE(review): 'idstring' is not among the defaults set above, so this
    # assumes misc.param_handling adds that key to P — confirm, otherwise
    # this line raises KeyError.
    if P['idstring']:
        idstring = idstring + '_' + P['idstring']
    save_directory = misc.mkdir(log_dir + '/' + idstring)

    # file with the key stats
    stats_file = save_directory + '/stats_' + log_id + '.p'

    if os.path.exists(stats_file) and not (overwrite):

        # cached results exist: load them instead of refitting
        S = pickle.load(open(stats_file, "rb"))

    else:

        # create single layer CNN with above parameters
        layer = {}
        layer['type'] = 'conv'
        layer['n_kern'] = n_resp
        layer['time_win_sec'] = time_win_sec
        layer['act'] = P['act']
        layer['rank'] = P['rank']
        layers = []
        layers.append(layer)

        # initialize, build, and train
        tf.reset_default_graph()
        # scale the initial weights inversely with the weights per kernel
        n_weights = (time_win_sec * sr * n_feats)
        net = cnn.Net(data_dims,
                      n_feats,
                      sr,
                      deepcopy(layers),
                      loss_type='squared_error',
                      weight_scale=P['weight_scale'] / n_weights,
                      seed=P['seed'],
                      log_dir=save_directory,
                      log_id=log_id)
        net.build()
        net.train(F,
                  D,
                  max_iter=P['max_iter'],
                  eval_interval=P['eval_interval'],
                  learning_rate=P['learning_rate'],
                  train_val_test=train_val_test,
                  early_stopping_steps=P['early_stopping_steps'],
                  print_iter=P['print_iter'])

        S = {}
        S['train_loss'] = net.train_loss
        S['val_loss'] = net.val_loss
        S['test_loss'] = net.test_loss
        S['W'] = net.layer_vals()[0]['W']
        S['Y'] = net.predict(F)
        S['train_val_test'] = train_val_test

        pickle.dump(S, open(stats_file, "wb"))

    print('Train loss:', S['train_loss'][-1])
    print('Val loss:', S['val_loss'][-1])
    print('Test loss:', S['test_loss'][-1])

    # correlation between predictions and data on the test partition (== 2)
    S['test_corr'] = np.corrcoef(
        S['Y'][S['train_val_test'] == 2, :, 0].flatten(),
        D[S['train_val_test'] == 2, :, 0].flatten())[0, 1]

    if plot_figures:

        # loss
        plt.plot(S['train_loss'])
        plt.plot(S['val_loss'])
        plt.plot(S['test_loss'])
        plt.legend(['Train', 'Val', 'Test'])
        plt.xlabel('Eval Iter')
        plt.ylabel('Loss')
        plt.savefig(save_directory + '/loss_' + log_id + '.pdf',
                    bbox_inches='tight')
        plt.show()

        # predictions
        xi = np.where(S['train_val_test'] == 2)[0]
        if sort_data_PC1:
            # order test stimuli by the data's first PC, sign-aligned
            [U, E, V] = np.linalg.svd(np.transpose(D[xi, :, 0]))
            feat_weights = V[0, :] * np.sign(U[:, 0].mean())
            xi = xi[np.flipud(np.argsort(feat_weights))]
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plot.imshow(D[xi, :, 0])
        plt.title('Data')
        plt.ylabel('Test stims')
        plt.xlabel('Time')
        plt.subplot(1, 2, 2)
        plot.imshow(S['Y'][xi, :, 0])
        plt.title('Prediction')
        plt.ylabel('Test stims')
        plt.xlabel('Time')
        del xi
        plt.savefig(save_directory + '/predictions_' + log_id + '.pdf',
                    bbox_inches='tight')
        plt.show()

        # plot weights, optionally sort by first PC
        if sort_weights_PC1:
            [U, E, V] = np.linalg.svd(S['W'][:, :, 0])
            feat_weights = V[0, :] * np.sign(U[:, 0].mean())
            xi = np.flipud(np.argsort(feat_weights))
        else:
            xi = np.arange(0, n_feats)
        plt.figure(figsize=(5, 5))
        plot.imshow(np.fliplr(np.transpose(S['W'][:, xi, 0])))
        plt.title('Weights')
        plt.xlabel('Time')
        plt.ylabel('Feats')
        del xi
        plt.savefig(save_directory + '/weights_' + log_id + '.pdf',
                    bbox_inches='tight')
        plt.show()

    return S
Esempio n. 12
0
        fig = plt.figure(figsize=(10, 4))
        # display 20 images
        for idx in np.arange(2):
            ax = fig.add_subplot(2, 10 / 2, idx + 1, xticks=[], yticks=[])
            imshow(images[idx])
            ax.set_title(classes[labels[idx]])
        plt.show()

    criterion = torch.nn.CrossEntropyLoss()

    # track test loss
    test_loss = 0.0
    class_correct = list(0. for i in range(2))
    class_total = list(0. for i in range(2))

    model = cnn.Net()
    model.load_state_dict(torch.load("model_cifar.pt"))
    if train_on_gpu:
        model.cuda()
    model.eval()

    i = 1
    # iterate over test data
    #print(len(test_loader)) #how many batches
    for data, target in test_loader:
        print("target", target)
        i = i + 1
        if len(target) != batch_size:
            continue

        # move tensors to GPU if CUDA is available
Esempio n. 13
0
def test(arguments):
    """Evaluate a saved CNN on recorded user/webpage observations and write a
    JSON classification report.

    `arguments` is argv-style: arguments[1] is the data folder, arguments[2]
    the model name; remaining string flags ('masked', 'scrolling', 'root',
    'overlap', user names, webpage names, 'split=NN', 'seed=...') toggle
    behavior.
    """
    # create settings from arguments
    settings = {
        'data_folder': Path(arguments[1]),
        'model_name': arguments[2],
        'masked': ('masked' in arguments),
        'scrolling': ('scrolling' in arguments),
        'root': ('root' in arguments),
        'overlap': ('overlap' in arguments)
    }

    # filter users and webpages to use
    users = names.get_all_users()
    users_to_use = list(set(users) & set(arguments))
    if users_to_use:
        users = users_to_use
    else:
        users = users[:1]  # default: only the first user

    webpages = names.get_all_webpages()
    webpages_to_use = list(set(webpages) & set(arguments))
    if webpages_to_use:
        webpages = webpages_to_use

    print('Loading Model')
    net = cnn.Net()
    state_file = settings['model_name'] + '.pth'
    net.load_state_dict(torch.load(settings['data_folder'] / state_file))
    print('Model loaded successfully')

    # print settings
    printer.print_settings(settings,
                           users,
                           users_to_use,
                           webpages,
                           webpages_to_use,
                           training=False)

    # create dict from CSV files
    csv_dict = dict_creator.create_dict(users, webpages, settings)

    # check if the observations are getting ORed or are separate by layers and their masks
    if settings['masked'] or settings['scrolling'] or settings['root']:
        all_obs = [
            ob for user in users for webpage in webpages
            for ob in csv_dict[user][webpage]['features_meta']
        ]
    else:
        all_obs = [
            ob for user in users for webpage in webpages
            for ob in csv_dict[user][webpage]['labeled_obs']
        ]

    if settings['overlap']:
        # keep only observations whose overlap region exceeds 32x32 pixels
        all_obs = list(
            filter(
                lambda o: (int(o['overlap_height']) > 32) and
                (int(o['overlap_width']) > 32), all_obs))

    # split obs into training data and testing data
    # random_seed is used so test_cnn.py will have the same list to work with
    split_percentage = int(
        next((s for s in arguments if 'split' in s), 'split=0').split('=')[-1])
    if split_percentage:
        if split_percentage > 99 or split_percentage < 1:
            print('Percentage has to be between 0 and 100, splitting aborted')
        else:
            print('Splitting data:')
            print(str(len(all_obs)) + ' observations total')
            random_seed = next((s for s in arguments if 'seed' in s),
                               'seed=' + settings['model_name']).split('=')[-1]
            random.Random(random_seed).shuffle(all_obs)
            # keep the tail portion: the complement of the training split
            all_obs = all_obs[int(len(all_obs) * (split_percentage) / 100):]
            print(str(len(all_obs)) + ' observations kept')

    # one open video capture per (user, webpage) pair
    vids = {}
    for user in users:
        vids[user] = {}
        for webpage in webpages:
            vids[user][webpage] = cv2.VideoCapture(
                get_video_file(user, webpage, settings))

    predictions = []
    net.eval()
    labeled = 0  # running count of positive targets
    print('Observations to test: ' + str(len(all_obs)))
    print('Testing started!')
    i = 0
    with torch.no_grad():
        for ob in all_obs:

            # prepare input
            input_frame = imp.get_merged_frame_pair(
                vids[ob['user']][ob['webpage']], ob, settings)
            # swap color axis because
            # numpy image: H x W x C
            # torch image: C X H X W
            input_frame = torch.from_numpy(input_frame.transpose((2, 0, 1)))
            inputs = [input_frame]
            # collapse the label to binary: anything > 0 counts as positive
            target = min(1, ob['label'])
            labeled += target

            # get prediction: sigmoid score, rounded to a hard 0/1 label
            inputs = torch.stack(inputs, dim=0)
            outputs = net(inputs)
            outputs = torch.sigmoid(outputs)
            output_labels = torch.round(outputs)
            prediction = output_labels.numpy().squeeze().tolist()
            not_rounded = outputs.numpy().squeeze().tolist()

            # fill prediction list
            predictions.append({
                'target': target,
                'prediction': prediction,
                'not_rounded': not_rounded,
                'webpage': ob['webpage'],
                'user': ob['user']
            })

            # print statistics
            i += 1
            if i % 200 == 0:
                print('Observations tested so far: ' + str(i))

    reports = {}
    target_names = ['Visually same', 'Visually different']
    print('REPORT\n')
    print('OVERALL')
    reports['overall'] = classification_report(
        [p['target'] for p in predictions],
        [p['prediction'] for p in predictions],
        target_names=target_names,
        output_dict=True)
    print(
        classification_report([p['target'] for p in predictions],
                              [p['prediction'] for p in predictions],
                              target_names=target_names))

    # per-user breakdown
    print('USERS')
    for u in users:
        print('User ' + u)
        reports[u] = classification_report(
            [p['target'] for p in predictions if p['user'] == u],
            [p['prediction'] for p in predictions if p['user'] == u],
            target_names=target_names,
            output_dict=True)
        print(
            classification_report(
                [p['target'] for p in predictions if p['user'] == u],
                [p['prediction'] for p in predictions if p['user'] == u],
                target_names=target_names))
    # per-webpage breakdown
    print('WEBPAGES')
    for wp in webpages:
        print('Webpage ' + wp)
        reports[wp] = classification_report(
            [p['target'] for p in predictions if p['webpage'] == wp],
            [p['prediction'] for p in predictions if p['webpage'] == wp],
            target_names=target_names,
            output_dict=True)
        print(
            classification_report(
                [p['target'] for p in predictions if p['webpage'] == wp],
                [p['prediction'] for p in predictions if p['webpage'] == wp],
                target_names=target_names))

    print('NOT ROUNDED OVERALL')
    # re-threshold so the number of positive predictions equals the number of
    # positive targets: sort by raw sigmoid score and mark the top `labeled`
    # observations as positive
    sorted_preds = sorted(predictions, key=lambda p: p['not_rounded'])
    print('Split at ' +
          str(sorted_preds[(len(all_obs) - labeled)]['not_rounded']))
    for pred in sorted_preds[:(len(all_obs) - labeled)]:
        pred['prediction'] = 0
    for pred in sorted_preds[(len(all_obs) - labeled):]:
        pred['prediction'] = 1
    reports['not_rounded_overall'] = classification_report(
        [p['target'] for p in sorted_preds],
        [p['prediction'] for p in sorted_preds],
        target_names=target_names,
        output_dict=True)
    print(
        classification_report([p['target'] for p in sorted_preds],
                              [p['prediction'] for p in sorted_preds],
                              target_names=target_names))

    # persist all reports as JSON next to the model
    report_file_name = settings['model_name'] + '_report.json'
    report_file = open(settings['data_folder'] / report_file_name, 'w')
    json.dump(reports, report_file)
    report_file.close()
    print('Report saved as ' + report_file_name)
Esempio n. 14
0
        gt_data = pkl.loads(fh.read())
    fh.close()
    gt_cqts = gt_data['gt_cqts']
    gt_samples = gt_data['gt_samples']
    gt_sampling_freqs = gt_data['gt_sampling_freqs']

    with open("pred_data.pkl", 'rb') as fh:
        pred_data = pkl.loads(fh.read())
    fh.close()
    pred_cqts = pred_data['pred_cqts']
    pred_samples = pred_data['pred_samples']
    pred_sampling_freqs = pred_data['pred_sampling_freqs']

    #plt.figure(figsize=(16,12))
    #plt.imshow(merge_images(gt_cqts[:16], pred_cqts[:16]))
    return gt_samples, gt_sampling_freqs, pred_samples, pred_sampling_freqs


if __name__ == "__main__":
    # Build the model in double precision on the configured device, train it
    # (with validation/eval monitoring), then evaluate on the held-out split.
    # NOTE(review): `device` must be defined at module level elsewhere in
    # this file — confirm.
    net = cnn.Net().double().to(device)
    train_data, test_data, val_data, eval_data = load_data()
    train_model(net, train_data, val_data, eval_data)
    test(net, test_data)
    #plot_curves()
    #gt_samples, gt_sampling_freqs, pred_samples, pred_sampling_freqs = plot_cqts()
#     for i in range(len(gt_samples)):
#         print("GT")
#         display(Audio(gt_samples[i], rate=2*gt_sampling_freqs[i]))
#         print("Pred")
#         display(Audio(pred_samples[i], rate=2*pred_sampling_freqs[i]))
Esempio n. 15
0
            sensitivity = occlusion_sensitivity(model,
                                                images,
                                                ids[:, [i]],
                                                patch=p,
                                                stride=stride,
                                                n_batches=n_batches)

            # Save results as image files
            for j in range(len(images)):
                print("\t#{}: {} ({:.5f})".format(j, classes[ids[j, i]],
                                                  probs[j, i]))

                # save_sensitivity(
                #     filename=osp.join(
                #         output_dir,
                #         "{}-{}-sensitivity-{}-{}.png".format(
                #             j, arch, p, classes[ids[j, i]]
                #         ),
                #     ),
                #     maps=sensitivity[j],
                # )


if __name__ == "__main__":

    # class labels recognized by the classifier
    classes = ['crossing', 'klaxon', 'noise']

    # size the network to the transform's input dimension, then run main
    tms = _tms.factory()
    net = cnn.Net(tms.input_size)
    main(net, classes, tms.input_size)
Esempio n. 16
0
def train(arguments):
    """Train the CNN on recorded user/webpage observations and save weights.

    `arguments` is argv-style: arguments[1] is the data folder, arguments[2]
    the model name; remaining string flags ('masked', 'scrolling', 'cuda',
    'root', 'balanced', 'overlap', 'load', 'loadautosave', 'saveall',
    user/webpage names, 'split=NN', 'seed=...', 'batch_size=N', 'epochs=N')
    toggle behavior.
    """
    # print help information
    if 'help' in arguments:
        printer.print_help()
        sys.exit()
    else:
        printer.print_help_notice()

    # create settings from arguments
    settings = {
        'data_folder': Path(arguments[1]),
        'masked': ('masked' in arguments),
        'scrolling': ('scrolling' in arguments),
        'has_cuda': ('cuda' in arguments),
        'root': ('root' in arguments),
        'balanced': ('balanced' in arguments),
        'overlap': ('overlap' in arguments),
        'model_name': arguments[2]
    }

    # filter users and webpages to use
    users = names.get_all_users()
    users_to_use = list(set(users) & set(arguments))
    if users_to_use:
        users = users_to_use
    else:
        # default: all but the first user (test() defaults to users[:1],
        # presumably the held-out test user — confirm)
        users = users[1:]

    webpages = names.get_all_webpages()
    webpages_to_use = list(set(webpages) & set(arguments))
    if webpages_to_use:
        webpages = webpages_to_use

    # print settings
    printer.print_settings(settings, users, users_to_use, webpages,
                           webpages_to_use)

    # create dict from CSV files
    csv_dict = dict_creator.create_dict(users, webpages, settings)

    # CNN
    net = cnn.Net()
    criterion = nn.BCEWithLogitsLoss()
    # optionally resume from a saved model or the autosave checkpoint
    if 'load' in arguments:
        state_file = settings['model_name'] + '.pth'
        net.load_state_dict(torch.load(settings['data_folder'] / state_file))
        print('Model loaded successfully')

    if 'loadautosave' in arguments:
        state_file = 'autosave.pth'
        net.load_state_dict(torch.load(settings['data_folder'] / state_file))
        print('Model loaded successfully')

    if settings['has_cuda']:
        device = torch.device("cuda")
        net = net.cuda()
        criterion = criterion.cuda()

    optimizer = optim.Adam(net.parameters(), lr=0.0001)
    i = 0
    # one open video capture per (user, webpage) pair
    vids = {}
    for user in users:
        vids[user] = {}
        for webpage in webpages:
            vids[user][webpage] = cv2.VideoCapture(
                get_video_file(user, webpage, settings))

    torch.cuda.empty_cache()

    # check if the observations are getting ORed or are separate by layers and their masks
    if settings['masked'] or settings['scrolling'] or settings[
            'root'] or settings['overlap']:
        all_obs = [
            ob for user in users for webpage in webpages
            for ob in csv_dict[user][webpage]['features_meta']
        ]
    else:
        all_obs = [
            ob for user in users for webpage in webpages
            for ob in csv_dict[user][webpage]['labeled_obs']
        ]

    if settings['overlap']:
        # keep only observations whose overlap region exceeds 32x32 pixels
        all_obs = list(
            filter(
                lambda o: (int(o['overlap_height']) > 32) and
                (int(o['overlap_width']) > 32), all_obs))

    # split obs into training data and testing data
    # random_seed is used so test_cnn.py will have the same list to work with
    split_percentage = int(
        next((s for s in arguments if 'split' in s), 'split=0').split('=')[-1])
    if split_percentage:
        if split_percentage > 99 or split_percentage < 1:
            print('Percentage has to be between 0 and 100, splitting aborted')
        else:
            print('Splitting data:')
            print(str(len(all_obs)) + ' observations total')
            random_seed = next((s for s in arguments if 'seed' in s),
                               'seed=' + settings['model_name']).split('=')[-1]
            random.Random(random_seed).shuffle(all_obs)
            # keep the head portion for training (test() keeps the tail)
            all_obs = all_obs[:int(len(all_obs) * split_percentage / 100)]
            print(str(len(all_obs)) + ' observations kept')

    # training parameters
    batch_size = int(
        next((s for s in arguments if 'batch_size' in s),
             'batch_size=10').split('=')[-1])
    epochs = int(
        next((s for s in arguments if 'epochs' in s),
             'epochs=5').split('=')[-1])
    running_loss = 0.0
    inputs = []
    targets = []

    # split obs between labels (positive vs. negative observations)
    obs1, obs2 = partition(all_obs, lambda x: x.get('label') > 0)
    random.shuffle(obs1)
    random.shuffle(obs2)

    # start training
    printer.print_training_start(obs1, obs2, batch_size, settings)
    for epoch in range(epochs):  # loop over the dataset multiple times
        print('Epoch ' + Fore.YELLOW + str(epoch + 1) + '/' + str(epochs) +
              Style.RESET_ALL + ' started')
        obs1, obs2, next_obs = prepare_next_observations(obs1, obs2, settings)
        for ob in next_obs:
            input_frame = imp.get_merged_frame_pair(
                vids[ob['user']][ob['webpage']], ob, settings)
            # swap color axis because
            # numpy image: H x W x C
            # torch image: C X H X W
            input_frame = torch.from_numpy(input_frame.transpose((2, 0, 1)))
            inputs.append(input_frame)
            # collapse the label to binary: anything > 0 counts as positive
            targets.append([min(1, ob['label'])])

            # train once a full batch has been accumulated; a trailing
            # partial batch (< batch_size) carries over to the next epoch
            # and is never trained on after the final epoch
            if len(inputs) >= batch_size:

                inputs = torch.stack(inputs, dim=0)
                # print(inputs.shape)
                # inputs = torch.from_numpy(inputs)
                targets = torch.Tensor(targets)

                if settings['has_cuda']:
                    inputs = inputs.cuda()
                    targets = targets.cuda()

                # forward + backward + optimize
                output = net(inputs)
                loss = criterion(output, targets)

                # zero the parameter gradients
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Clean up GPU memory
                del inputs
                del targets
                torch.cuda.empty_cache()

                # current loss
                running_loss += loss.item() * batch_size
                inputs = []
                targets = []

            # print statistics (i counts observations, not batches)
            if i % 200 == 199:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 200))
                running_loss = 0.0
            i += 1
        # checkpoint after every epoch
        autosave_path = settings['data_folder'] / 'autosave.pth'
        torch.save(net.state_dict(), autosave_path)
        print('Autosave')
        if ('saveall' in arguments) and (epoch < (epochs - 1)):
            file_name = settings['model_name'] + '_afterEpoch' + str(
                epoch + 1) + '.pth'
            PATH = settings['data_folder'] / file_name
            torch.save(net.state_dict(), PATH)

    print('Finished Training')

    # final save under the model name
    file_name = settings['model_name'] + '.pth'
    PATH = settings['data_folder'] / file_name
    torch.save(net.state_dict(), PATH)

    print('Saved as ' + file_name)