Code Example #1
def main():
    args = parse_args()
    train_dataset, test_dataset = dataset.get_dataset(args.path,
                                                      args.use_augmentation,
                                                      args.use_fivecrop)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch,
                              shuffle=True,
                              num_workers=args.worker,
                              pin_memory=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch,
                             shuffle=False,
                             num_workers=args.worker,
                             pin_memory=True)
    if args.cuda:
        torch.cuda.set_device(0)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    if args.model == 'ResNet18':
        mymodel = model.ResNet18(args.frozen_layers).to(device)
    elif args.model == 'ResNet34':
        mymodel = model.ResNet34(args.frozen_layers).to(device)
    elif args.model == 'ResNet50':
        mymodel = model.ResNet50(args.frozen_layers).to(device)
    elif args.model == 'DenseNet':
        mymodel = model.DenseNet().to(device)
    else:
        raise ValueError('Unknown model: ' + args.model)
    op = optim.Adam(mymodel.parameters(), lr=args.lr)
    train_losses, test_mF1s, test_precisions, test_recalls = [], [], [], []
    early = args.early
    for i in range(args.epoch):
        train_loss = train.train(mymodel, op, train_loader, i, device,
                                 args.log, utils.pos_weight)
        mF1, recall, precision = test.test(mymodel, test_loader, device,
                                           args.use_fivecrop)
        train_losses.append(train_loss)
        test_mF1s.append(mF1)
        test_precisions.append(precision)
        test_recalls.append(recall)
        early = utils.early_stop(test_mF1s, early)
        if early <= 0:
            break
    utils.save_log(mymodel, train_losses, test_mF1s, test_precisions,
                   test_recalls)
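
The `utils.early_stop` helper is not shown above. A minimal sketch consistent with how it is called (a patience counter, refilled whenever the newest mF1 is the best seen so far and decremented otherwise; the refill default is an assumption, since the caller seeds the budget with `args.early`) could be:

def early_stop(scores, patience, initial_patience=5):
    # Hypothetical patience-based early stopping. initial_patience is an
    # assumed default; training stops once the returned budget reaches 0.
    if len(scores) < 2 or scores[-1] > max(scores[:-1]):
        return initial_patience  # improvement: refill the budget
    return patience - 1          # no improvement: burn one unit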
Code Example #2
def main():
    args = parse_args()
    save = True
    use_gpu = args.cuda == 'True'
    load = args.load == 'True'
    train_tfs = transforms.Compose([
        transforms.Resize(299),
        transforms.RandomResizedCrop(299),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    ds = my_dataset("train", train_tfs)
    dataset_size = len(ds)
    print(dataset_size)
    train_ds, val_ds = torch.utils.data.random_split(ds, [13000, 1463])
    train_loader = torch.utils.data.DataLoader(train_ds,
                                               batch_size=args.BS,
                                               shuffle=False,
                                               num_workers=8)
    val_loader = torch.utils.data.DataLoader(val_ds,
                                             batch_size=args.BS,
                                             shuffle=False,
                                             num_workers=8)
    print('train: ', len(train_ds))
    print('validation:', len(val_ds))
    print(type(ds), type(train_ds))
    if args.model == 'ResNet18':
        test_model = model.ResNet18()
    elif args.model == 'ResNet50':
        test_model = model.ResNet50()
    elif args.model == 'Inception':
        test_model = model.Inception()
    elif args.model == 'DenseNet':
        test_model = model.DenseNet()
    else:
        raise ValueError('Unknown model: ' + args.model)
    if use_gpu:
        test_model = test_model.cuda()
    if load:
        test_model.load_state_dict(torch.load('params' + args.model + '.pkl'))
    optimizer = optim.Adam(test_model.parameters(), lr=args.lr)
    print(use_gpu)
    result = train(test_model, args.epoch, optimizer, train_loader, val_loader,
                   args.model, save, use_gpu)
    test(result, val_loader, use_gpu)
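
The split sizes [13000, 1463] hard-code a dataset of exactly 14463 images. A more robust sketch derives the sizes from len(ds) instead (the 90/10 ratio here is an assumption):

train_size = int(0.9 * len(ds))  # assumed 90/10 train/validation ratio
val_size = len(ds) - train_size  # remainder goes to validation
train_ds, val_ds = torch.utils.data.random_split(ds, [train_size, val_size])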
Code Example #3
                                     num_workers=20)
        NUM_CLASSES = DATASET.num_classes
        print('Data path: ' + args.data_path)
        print('Number of classes: %d' % NUM_CLASSES)
        print('Batch size: %d' % args.batch_size)
        print('Epoch size: %d' % args.epoch_size)

        num_batches = len(DATALOADER)
        batch_total = num_batches * args.epoch_size

        if args.model == 'ResNet18':
            MODEL = model.ResNet18(pseudo=args.pseudo)
        elif args.model == 'ResNet34':
            MODEL = model.ResNet34(pseudo=args.pseudo)
        elif args.model == 'ResNet50':
            MODEL = model.ResNet50(pseudo=args.pseudo)
        elif args.model == 'ResNet101':
            MODEL = model.ResNet101(pseudo=args.pseudo)
        else:
            raise ValueError('Unknown model: ' + args.model)

        ARCFACE = lossfunction.Arcface(512, NUM_CLASSES)

        if args.optim == 'Adam':
            OPTIMIZER = torch.optim.Adam(
                [{'params': MODEL.parameters()},
                 {'params': ARCFACE.parameters()}],
                lr=1e-4)
            SCHEDULER = torch.optim.lr_scheduler.MultiStepLR(OPTIMIZER,
                                                             milestones=[10],
                                                             gamma=0.5)
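
lossfunction.Arcface(512, NUM_CLASSES) is defined outside this excerpt. Below is a minimal sketch of an ArcFace head in the spirit of Deng et al. (2019); only the constructor arguments (embedding size, class count) come from the snippet, while the scale s=64.0 and margin m=0.5 defaults are assumptions:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Arcface(nn.Module):
    # Additive angular margin: the ground-truth logit becomes
    # s * cos(theta + m); every other logit stays s * cos(theta).
    def __init__(self, embedding_size, num_classes, s=64.0, m=0.5):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_classes, embedding_size))
        nn.init.xavier_uniform_(self.weight)
        self.s, self.m = s, m

    def forward(self, embeddings, labels):
        # cosine similarity between normalised embeddings and class centres
        cosine = F.linear(F.normalize(embeddings), F.normalize(self.weight))
        theta = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7))
        target = F.one_hot(labels, cosine.size(1)).bool()
        logits = torch.where(target, torch.cos(theta + self.m), cosine)
        return F.cross_entropy(self.s * logits, labels)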
Code Example #4
def train_model(output_folder, batch_size, reader_count, train_lmdb_filepath,
                test_lmdb_filepath, use_augmentation, learning_rate,
                test_every_n_steps, early_stopping_count):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    training_checkpoint_filepath = None

    # # setup mixed precision training to use FP16
    # policy = mixed_precision.Policy('mixed_float16')
    # mixed_precision.set_policy(policy)

    # uses all available devices
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():

        # scale the batch size based on the GPU count
        global_batch_size = batch_size * mirrored_strategy.num_replicas_in_sync
        # scale the number of I/O readers based on the GPU count
        reader_count = reader_count * mirrored_strategy.num_replicas_in_sync

        print('Setting up test image reader')
        test_reader = imagereader.ImageReader(test_lmdb_filepath,
                                              use_augmentation=False,
                                              shuffle=False,
                                              num_workers=reader_count)
        print('Test Reader has {} images'.format(
            test_reader.get_image_count()))

        print('Setting up training image reader')
        train_reader = imagereader.ImageReader(
            train_lmdb_filepath,
            use_augmentation=use_augmentation,
            shuffle=True,
            num_workers=reader_count)
        print('Train Reader has {} images'.format(
            train_reader.get_image_count()))

        try:  # if any errors happen we want to catch them and shut down the multiprocess readers
            print('Starting Readers')
            train_reader.startup()
            test_reader.startup()

            train_dataset = train_reader.get_tf_dataset()
            train_dataset = train_dataset.batch(global_batch_size).prefetch(
                reader_count)
            train_dataset = mirrored_strategy.experimental_distribute_dataset(
                train_dataset)

            test_dataset = test_reader.get_tf_dataset()
            test_dataset = test_dataset.batch(global_batch_size).prefetch(
                reader_count)
            test_dataset = mirrored_strategy.experimental_distribute_dataset(
                test_dataset)

            print('Creating model')
            resnet = model.ResNet50(global_batch_size,
                                    train_reader.get_image_size(),
                                    learning_rate)

            checkpoint = tf.train.Checkpoint(optimizer=resnet.get_optimizer(),
                                             model=resnet.get_keras_model())

            # train_epoch_size = train_reader.get_image_count()/batch_size
            train_epoch_size = test_every_n_steps
            test_epoch_size = test_reader.get_image_count() / batch_size

            test_loss = list()

            # Prepare the metrics.
            train_loss_metric = tf.keras.metrics.Mean('train_loss',
                                                      dtype=tf.float32)
            test_loss_metric = tf.keras.metrics.Mean('test_loss',
                                                     dtype=tf.float32)

            current_time = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
            train_log_dir = os.path.join(output_folder,
                                         'tensorboard-' + current_time,
                                         'train')
            if not os.path.exists(train_log_dir):
                os.makedirs(train_log_dir)
            test_log_dir = os.path.join(output_folder,
                                        'tensorboard-' + current_time, 'test')
            if not os.path.exists(test_log_dir):
                os.makedirs(test_log_dir)

            train_summary_writer = tf.summary.create_file_writer(train_log_dir)
            test_summary_writer = tf.summary.create_file_writer(test_log_dir)

            epoch = 0
            print('Running Network')
            while True:  # loop until early stopping
                print('---- Epoch: {} ----'.format(epoch))

                if epoch == 0:
                    cur_train_epoch_size = min(1000, train_epoch_size)
                    print(
                        'Performing Adam Optimizer learning rate warmup for {} steps'
                        .format(cur_train_epoch_size))
                    resnet.set_learning_rate(learning_rate / 10)
                else:
                    cur_train_epoch_size = train_epoch_size
                    resnet.set_learning_rate(learning_rate)

                # Iterate over the batches of the train dataset.
                start_time = time.time()
                for step, (batch_images,
                           batch_labels) in enumerate(train_dataset):
                    if step > cur_train_epoch_size:
                        break

                    inputs = (batch_images, batch_labels, train_loss_metric)
                    resnet.dist_train_step(mirrored_strategy, inputs)

                    print('Train Epoch {}: Batch {}/{}: Loss {}'.format(
                        epoch, step, train_epoch_size,
                        train_loss_metric.result()))
                    with train_summary_writer.as_default():
                        tf.summary.scalar('loss',
                                          train_loss_metric.result(),
                                          step=int(epoch * train_epoch_size +
                                                   step))
                    train_loss_metric.reset_states()

                # Iterate over the batches of the test dataset.
                epoch_test_loss = list()
                for step, (batch_images,
                           batch_labels) in enumerate(test_dataset):
                    if step > test_epoch_size:
                        break

                    inputs = (batch_images, batch_labels, test_loss_metric)
                    loss_value = resnet.dist_test_step(mirrored_strategy,
                                                       inputs)

                    epoch_test_loss.append(loss_value.numpy())
                    # print('Test Epoch {}: Batch {}/{}: Loss {}'.format(epoch, step, test_epoch_size, loss_value))
                test_loss.append(np.mean(epoch_test_loss))

                print('Test Epoch: {}: Loss = {}'.format(
                    epoch, test_loss_metric.result()))
                with test_summary_writer.as_default():
                    tf.summary.scalar('loss',
                                      test_loss_metric.result(),
                                      step=int((epoch + 1) * train_epoch_size))
                test_loss_metric.reset_states()

                with open(os.path.join(output_folder, 'test_loss.csv'),
                          'w') as csvfile:
                    for i in range(len(test_loss)):
                        csvfile.write(str(test_loss[i]))
                        csvfile.write('\n')

                print('Epoch took: {} s'.format(time.time() - start_time))

                # determine if to record a new checkpoint based on best test loss
                if (len(test_loss) - 1) == np.argmin(test_loss):
                    # save tf checkpoint
                    print('Test loss improved: {}, saving checkpoint'.format(
                        np.min(test_loss)))
                    # checkpoint.save(os.path.join(output_folder, 'checkpoint', "ckpt")) # does not overwrite
                    training_checkpoint_filepath = checkpoint.write(
                        os.path.join(output_folder, 'checkpoint', "ckpt"))

                # determine early stopping
                CONVERGENCE_TOLERANCE = 1e-4
                print('Best Current Epoch Selection:')
                print('Test Loss:')
                print(test_loss)
                min_test_loss = np.min(test_loss)
                error_from_best = np.abs(np.asarray(test_loss) - min_test_loss)
                error_from_best[error_from_best < CONVERGENCE_TOLERANCE] = 0
                # first epoch whose loss is within tolerance of the best
                best_epoch = np.where(error_from_best == 0)[0][0]
                print('Best epoch: {}'.format(best_epoch))

                if len(test_loss) - best_epoch > early_stopping_count:
                    break  # break the epoch loop
                epoch = epoch + 1

        finally:  # if any errors happened during training, shut down the disk readers
            print('Shutting down train_reader')
            train_reader.shutdown()
            print('Shutting down test_reader')
            test_reader.shutdown()

    # convert training checkpoint to the saved model format
    if training_checkpoint_filepath is not None:
        # restore the checkpoint and generate a saved model
        resnet = model.ResNet50(global_batch_size,
                                train_reader.get_image_size(), learning_rate)
        checkpoint = tf.train.Checkpoint(optimizer=resnet.get_optimizer(),
                                         model=resnet.get_keras_model())
        checkpoint.restore(training_checkpoint_filepath)
        tf.saved_model.save(resnet.get_keras_model(),
                            os.path.join(output_folder, 'saved_model'))
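
The stopping rule above ends training once early_stopping_count epochs have passed since the last test loss within CONVERGENCE_TOLERANCE of the best. A toy trace of that logic, with invented loss values:

import numpy as np

test_loss = [0.50, 0.30, 0.25, 0.25004, 0.26, 0.27]  # invented losses
tolerance, early_stopping_count = 1e-4, 2
error_from_best = np.abs(np.asarray(test_loss) - np.min(test_loss))
error_from_best[error_from_best < tolerance] = 0
best_epoch = np.where(error_from_best == 0)[0][0]  # 2: 0.25004 ties within tolerance
print(len(test_loss) - best_epoch > early_stopping_count)  # True, so stop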
Code Example #5
import argparse

import torch

import model
from pre_process import transform_train, transform_test

# whether to use the GPU
use_cuda = torch.cuda.is_available()

# default parameters
DATA_ROOT = '../data/'
num_epochs = 50
batch_size = 128

model_names = {
    'dnn': model.DNN(3072, 4096, 10),
    'cnn': model.CNN(),
    'resnet18': model.ResNet18(),
    'resnet34': model.ResNet34(),
    'resnet50': model.ResNet50()
}


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type',
                        type=str,
                        default='dnn',
                        help="the type of model")
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        help='the initial learning rate')
    parser.add_argument('--batch_size',
                        type=int,
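
Note that model_names instantiates all five networks eagerly just so one can be selected. A lighter sketch maps names to constructors and builds only the chosen model (args as returned by get_args):

model_factories = {
    'dnn': lambda: model.DNN(3072, 4096, 10),
    'cnn': model.CNN,
    'resnet18': model.ResNet18,
    'resnet34': model.ResNet34,
    'resnet50': model.ResNet50,
}
net = model_factories[args.model_type]()  # construct only the requested network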
Code Example #6
    k_vals = [1, 2, 4, 8]
    parser.add_argument('--dataset_path',
                        default='../stanford-cars-dataset',
                        help='path for input data')
    parser.add_argument('--num_classes',
                        type=int,
                        default=196,
                        help='Number of classes in dataset. Default is '
                             'num_classes in CARS 196 train data')
    args = parser.parse_args()
    dataloaders = CARS_196_Loader.give_CARS_dataloaders(args)

#dataloaders = loader.give_dataloaders(args)
# print(len(dataloaders['training'].dataset))
model = net.ResNet50(args)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.to(device)

if args.load_model:
    model.load_state_dict(torch.load(args.model_dict_path))

if not args.query:
    train.train(args, model, dataloaders, k_vals)
else:
    print(
        query.query(args, dataloaders['testing'], dataloaders['evaluation'],
                    model))
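
The k_vals = [1, 2, 4, 8] passed to train.train suggest Recall@k evaluation, the standard metric on CARS196-style retrieval benchmarks. A self-contained sketch of that metric, for illustration only (not the project's own train/query code):

import torch
import torch.nn.functional as F

def recall_at_k(embeddings, labels, k_vals=(1, 2, 4, 8)):
    # A query counts as a hit at k if any of its k nearest neighbours
    # (by cosine similarity, excluding itself) shares its label.
    x = F.normalize(embeddings, dim=1)
    sims = x @ x.t()
    sims.fill_diagonal_(float('-inf'))  # never match a query to itself
    idx = sims.topk(max(k_vals), dim=1).indices
    hits = labels[idx] == labels.unsqueeze(1)
    return {k: hits[:, :k].any(dim=1).float().mean().item() for k in k_vals}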
Code Example #7
File: main.py (project: chhkang/ResNet)
def main():
    global args, start_epoch, best_acc1
    args = config()

    if args.cuda and not torch.cuda.is_available():
        raise Exception('No GPU found, please run without --cuda')

    print('\n=> Build ResNet..')
    model = mo.ResNet50()
    print(model)
    print('==> Complete build')

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay,
                          nesterov=True)
    start_epoch = 0
    n_retrain = 0

    if args.cuda:
        torch.cuda.set_device(args.gpuids[0])
        with torch.cuda.device(args.gpuids[0]):
            model = model.cuda()
            criterion = criterion.cuda()
        model = nn.DataParallel(model,
                                device_ids=args.gpuids,
                                output_device=args.gpuids[0])
        cudnn.benchmark = True

    # checkpoint file
    ckpt_dir = pathlib.Path('checkpoint')
    ckpt_file = ckpt_dir / args.dataset / args.ckpt

    # for resuming training
    if args.resume:
        if isfile(ckpt_file):
            print('\n==> Loading Checkpoint \'{}\''.format(args.ckpt))
            checkpoint = load_model(model, ckpt_file, args)

            start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])

            print('==> Loaded Checkpoint \'{}\' (epoch {})'.format(
                args.ckpt, start_epoch))
        else:
            print('==> no checkpoint found \'{}\''.format(args.ckpt))
            return

    # Data loading
    print('\n==> Load data..')
    train_loader, val_loader = DataLoader(args.batch_size, args.workers,
                                          args.datapath, args.cuda)

    # for evaluation
    if args.evaluate:
        if isfile(ckpt_file):
            print('\n==> Loading Checkpoint \'{}\''.format(args.ckpt))
            checkpoint = load_model(model, ckpt_file, args)

            print('==> Loaded Checkpoint \'{}\' (epoch {})'.format(
                args.ckpt, start_epoch))

            # evaluate on validation set
            print('\n===> [ Evaluation ]')
            start_time = time.time()
            acc1, acc5 = validate(val_loader, model, criterion)
            elapsed_time = time.time() - start_time
            print('====> {:.2f} seconds to evaluate this model\n'.format(
                elapsed_time))
            return
        else:
            print('==> no checkpoint found \'{}\''.format(args.ckpt))
            return

    # train...
    train_time = 0.0
    validate_time = 0.0
    lr = args.lr
    list_Acc1 = []
    list_Acc5 = []
    list_epoch = []
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, lr)
        print('\n==> Epoch: {}, lr = {}'.format(
            epoch, optimizer.param_groups[0]["lr"]))

        # train for one epoch
        print('===> [ Training ]')
        start_time = time.time()
        acc1_train, acc5_train = train(train_loader,
                                       epoch=epoch,
                                       model=model,
                                       criterion=criterion,
                                       optimizer=optimizer)
        elapsed_time = time.time() - start_time
        train_time += elapsed_time
        print(
            '====> {:.2f} seconds to train this epoch\n'.format(elapsed_time))

        # evaluate on validation set
        print('===> [ Validation ]')
        start_time = time.time()
        acc1_valid, acc5_valid = validate(val_loader, model, criterion)
        elapsed_time = time.time() - start_time
        validate_time += elapsed_time
        print('====> {:.2f} seconds to validate this epoch\n'.format(
            elapsed_time))

        # remember best Acc@1 and save checkpoint
        is_best = acc1_valid > best_acc1
        best_acc1 = max(acc1_valid, best_acc1)
        state = {
            'epoch': epoch + 1,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        save_model(state, epoch, is_best, args)
        list_Acc1.append(acc1_valid)
        list_Acc5.append(acc5_valid)
        list_epoch.append(epoch)
    plt.plot(list_epoch, list_Acc1)
    plt.plot(list_epoch, list_Acc5)
    plt.legend(['ACC1', 'ACC5'])
    avg_train_time = train_time / (args.epochs - start_epoch)
    avg_valid_time = validate_time / (args.epochs - start_epoch)
    total_train_time = train_time + validate_time
    print('====> average training time per epoch: {:,}m {:.2f}s'.format(
        int(avg_train_time // 60), avg_train_time % 60))
    print('====> average validation time per epoch: {:,}m {:.2f}s'.format(
        int(avg_valid_time // 60), avg_valid_time % 60))
    print('====> training time: {}h {}m {:.2f}s'.format(
        int(train_time // 3600), int((train_time % 3600) // 60),
        train_time % 60))
    print('====> validation time: {}h {}m {:.2f}s'.format(
        int(validate_time // 3600), int((validate_time % 3600) // 60),
        validate_time % 60))
    print('====> total training time: {}h {}m {:.2f}s'.format(
        int(total_train_time // 3600), int((total_train_time % 3600) // 60),
        total_train_time % 60))
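
adjust_learning_rate is called every epoch but defined elsewhere. A common step-decay implementation matching its call signature (the divide-by-10-every-30-epochs schedule is an assumption borrowed from the classic ImageNet recipe):

def adjust_learning_rate(optimizer, epoch, base_lr):
    # Hypothetical step decay: lr = base_lr * 0.1 ** (epoch // 30).
    lr = base_lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr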
Code Example #8
    recall = correct / label_cnt
    f1 = 2 * (precision * recall) / (precision + recall)
    wf1 = torch.sum(weights * f1).item()
    mf1 = torch.mean(f1).item()
    print('class f1', f1, ' wf1', wf1, ' mf1', mf1)

    test_precision = 100. * precision.mean().item()
    test_recall = 100. * recall.mean().item()
    print('\nTest set: recall: {}/{} ({:.0f}%), precision: {}/{} ({:.0f}%)\n'.
          format(correct.sum().item(),
                 label_cnt.sum().item(), test_recall,
                 correct.sum().item(),
                 predict_cnt.sum().item(), test_precision))
    return mf1, test_recall, test_precision


if __name__ == "__main__":
    torch.cuda.set_device(3)
    device = torch.device('cuda')
    mymodel = model.ResNet50('none')
    mymodel.load('./ResNet50_bigdata_15005716.pt')
    mymodel = mymodel.to(device)
    FIVE = True
    _, test_dataset = dataset.get_dataset('./dataset', False, FIVE)
    loader = DataLoader(test_dataset,
                        batch_size=32,
                        shuffle=False,
                        num_workers=16,
                        pin_memory=True)
    test(mymodel, loader, device, FIVE)
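
The per-class tensors correct, label_cnt, and predict_cnt that feed the F1 computation are accumulated earlier in test, outside this excerpt. For a multi-label model they are typically built per batch along these lines (a sketch assuming sigmoid outputs thresholded at 0.5):

pred = (output.sigmoid() > 0.5).float()  # (batch, num_classes) 0/1 predictions
correct += (pred * target).sum(dim=0)    # true positives per class
predict_cnt += pred.sum(dim=0)           # predicted positives per class
label_cnt += target.sum(dim=0)           # ground-truth positives per class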
Code Example #9
File: train.py (project: vpamrit/DeepMed-Assignment1)
def main(args):
    # device configuration
    if args.cpu is not None:
        device = torch.device('cpu')
    elif args.gpu is not None:
        if not torch.cuda.is_available():
            print("GPU / cuda reported as unavailable to torch")
            exit(1)
        device = torch.device('cuda')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create model directory
    if not os.path.exists(args.model_save_dir):
        os.makedirs(args.model_save_dir)

    train_data = ld.get_data(labels_file=args.labels_file,
                             root_dir=args.train_image_dir,
                             mode="absolute")

    validation_data = ld.get_data(labels_file=args.labels_file,
                                  root_dir=args.validation_image_dir,
                                  mode="absolute")

    train_loader = DataLoader(dataset=train_data,
                              batch_size=args.batch_size,
                              shuffle=True)
    val_loader = DataLoader(dataset=validation_data,
                            batch_size=args.validation_batch_size)

    # Build the models
    if args.num_layers is not None and args.block_type is not None:
        if args.block_type == "bottleneck":
            net = model.ResNet(model.Bottleneck,
                               args.num_layers,
                               dropout=args.dropout)
        else:
            net = model.ResNet(model.BasicBlock,
                               args.num_layers,
                               dropout=args.dropout)
    else:
        if args.resnet_model == 152:
            net = model.ResNet152(args.dropout)
        elif args.resnet_model == 101:
            net = model.ResNet101(args.dropout)
        elif args.resnet_model == 50:
            net = model.ResNet50(args.dropout)
        elif args.resnet_model == 34:
            net = model.ResNet34(args.dropout)
        else:
            net = model.ResNet101(args.dropout)

    #load the model to the appropriate device
    net = net.to(device)
    params = net.parameters()

    # Loss and optimizer
    criterion = nn.MSELoss()  #best for regression

    if args.optim is not None:
        if args.optim == "adadelta":
            optimizer = torch.optim.Adadelta(params, lr=args.learning_rate)
        elif args.optim == "adagrad":
            optimizer = torch.optim.Adagrad(params, lr=args.learning_rate)
        elif args.optim == "adam":
            optimizer = torch.optim.Adam(params, lr=args.learning_rate)
        elif args.optim == "adamw":
            optimizer = torch.optim.AdamW(params, lr=args.learning_rate)
        elif args.optim == "rmsprop":
            optimizer = torch.optim.RMSprop(params, lr=args.learning_rate)
        elif args.optim == "sgd":
            optimizer = torch.optim.SGD(params, lr=args.learning_rate)
        else:
            raise ValueError("unknown optimizer: " + args.optim)
    else:
        optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    val_acc_history = []
    train_acc_history = []
    failed_runs = 0
    prev_loss = float("inf")

    for epoch in range(args.num_epochs):
        running_loss = 0.0
        total_loss = 0.0

        for i, (inputs, labels) in enumerate(train_loader, 0):
            net.train()

            # move inputs and coordinate labels to the target device
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = net(inputs.float())
            loss = criterion(outputs.float(), labels.float())
            loss.backward()

            torch.nn.utils.clip_grad_norm_(net.parameters(),
                                           args.clipping_value)
            optimizer.step()
            running_loss += loss.item()
            total_loss += loss.item()
            if i % 2 == 0:  # print the running loss every 2 mini-batches
                print('[%d, %5d] loss: %.5f' %
                      (epoch + 1, i + 1, running_loss / 2))
                running_loss = 0.0

        loss = 0.0

        #compute validation loss at the end of the epoch
        for i, (inputs, labels) in enumerate(val_loader, 0):
            inputs, labels = inputs.to(device), labels.to(device)
            net.eval()
            with torch.no_grad():
                outputs = net(inputs.float())
                loss += criterion(outputs, labels.float()).item()

        print("------------------------------------------------------------")
        print("Epoch %5d" % (epoch + 1))
        print("Training loss: {}, Avg Loss: {}".format(
            total_loss, total_loss / train_data.__len__()))
        print("Validation Loss: {}, Avg Loss: {}".format(
            loss, loss / validation_data.__len__()))
        print("------------------------------------------------------------")

        val_acc_history.append(loss)
        train_acc_history.append(total_loss)

        # save the model at the desired step
        if (epoch + 1) % args.save_step == 0:
            torch.save(net.state_dict(),
                       os.path.join(args.model_save_dir,
                                    "resnet" + str(epoch + 1) + ".pt"))

        ##stopping conditions
        if failed_runs > 5 and prev_loss < loss:
            break
        elif prev_loss < loss:
            failed_runs += 1
        else:
            failed_runs = 0

        prev_loss = loss

    # create a plot of the loss
    plt.title("Training vs Validation Loss")
    plt.xlabel("Training Epochs")
    plt.ylabel("Loss")
    plt.plot(range(1,
                   len(val_acc_history) + 1),
             val_acc_history,
             label="Validation loss")
    plt.plot(range(1,
                   len(train_acc_history) + 1),
             train_acc_history,
             label="Training loss")
    plt.xticks(np.arange(1, len(train_acc_history) + 1, 1.0))
    plt.legend()
    plt.ylim((0, max([max(val_acc_history), max(train_acc_history)])))

    if args.save_training_plot is not None:
        plt.savefig(args.save_training_plot + "loss_plot.png")

    plt.show()
    print('Finished Training')