Code Example #1
def main():
    """Print performance metrics for model at specified epoch."""
    # Data loaders
    tr_loader, va_loader, te_loader, _ = get_train_val_test_loaders(
        task="target",
        batch_size=config("cnn.batch_size"),
    )

    # Model
    model = Target()

    # Define loss function
    criterion = torch.nn.CrossEntropyLoss()

    # Attempt to restore the latest checkpoint, if one exists
    print("Loading cnn...")
    model, start_epoch, stats = restore_checkpoint(model,
                                                   config("cnn.checkpoint"))

    axes = utils.make_training_plot()

    # Evaluate the model
    evaluate_epoch(
        axes,
        tr_loader,
        va_loader,
        te_loader,
        model,
        criterion,
        start_epoch,
        stats,
        include_test=True,
        update_plot=False,
    )
Code Example #2
def main():
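    """Train the autoencoder classifier with a frozen encoder and report validation results."""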
    # data loaders
    tr_loader, va_loader, te_loader, _ = get_train_val_test_loaders(
        num_classes=config('autoencoder.classifier.num_classes'))

    ae_classifier = AutoencoderClassifier(config('autoencoder.ae_repr_dim'),
                                          config('autoencoder.classifier.num_classes'))
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(ae_classifier.parameters(),
                                 lr=config('autoencoder.classifier.learning_rate'))

    # freeze the weights of the encoder
    for name, param in ae_classifier.named_parameters():
        if 'fc1.' in name or 'fc2.' in name:
            param.requires_grad = False

    # Attempt to restore the latest checkpoint, if one exists
    print('Loading autoencoder...')
    ae_classifier, _, _ = restore_checkpoint(ae_classifier,
                                             config('autoencoder.checkpoint'), force=True, pretrain=True)
    print('Loading autoencoder classifier...')
    ae_classifier, start_epoch, stats = restore_checkpoint(ae_classifier,
                                                           config('autoencoder.classifier.checkpoint'))

    axes = utils.make_cnn_training_plot(name='Autoencoder Classifier')

    # Evaluate the randomly initialized model
    _evaluate_epoch(axes, tr_loader, va_loader, ae_classifier, criterion,
                    start_epoch, stats)

    # Loop over the entire dataset multiple times
    for epoch in range(start_epoch, config('autoencoder.classifier.num_epochs')):
        # Train model
        _train_epoch(tr_loader, ae_classifier, criterion, optimizer)

        # Evaluate model
        _evaluate_epoch(axes, tr_loader, va_loader, ae_classifier, criterion,
                        epoch + 1, stats)

        # Save model parameters
        save_checkpoint(ae_classifier, epoch + 1,
                        config('autoencoder.classifier.checkpoint'), stats)

    print('Finished Training')
    with torch.no_grad():
        y_true, y_pred = [], []
        for X, y in va_loader:
            output = ae_classifier(X)
            predicted = predictions(output.data)
            y_true.extend(y)
            y_pred.extend(predicted)
        print("Validation data confusion matrix:")
        print(confusion_matrix(y_true, y_pred))


    # Save figure and keep plot open
    utils.save_cnn_training_plot(name='ae_clf')
    utils.hold_training_plot()
Code Example #3
def main():
    """Train transfer learning model and display training plots.

    Train four different models with {0, 1, 2, 3} layers frozen.
    """
    # data loaders
    tr_loader, va_loader, te_loader, _ = get_train_val_test_loaders(
        task="target",
        batch_size=config("target.batch_size"),
    )

    freeze_none = Target()
    print("Loading source...")
    freeze_none, _, _ = restore_checkpoint(
        freeze_none, config("source.checkpoint"), force=True, pretrain=True
    )

    freeze_one = copy.deepcopy(freeze_none)
    freeze_two = copy.deepcopy(freeze_none)
    freeze_three = copy.deepcopy(freeze_none)

    freeze_layers(freeze_one, 1)
    freeze_layers(freeze_two, 2)
    freeze_layers(freeze_three, 3)

    train(tr_loader, va_loader, te_loader, freeze_none, "./checkpoints/target0/", 0)
    train(tr_loader, va_loader, te_loader, freeze_one, "./checkpoints/target1/", 1)
    train(tr_loader, va_loader, te_loader, freeze_two, "./checkpoints/target2/", 2)
    train(tr_loader, va_loader, te_loader, freeze_three, "./checkpoints/target3/", 3)
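The freeze_layers helper these calls rely on is not shown in the excerpt. A minimal sketch consistent with how it is called above; the conv1/conv2/... attribute names are an assumption about the Target model, not something this excerpt confirms:

def freeze_layers(model, num_layers=0):
    """Freeze the first num_layers layers of the model (a sketch)."""
    # Assumption: the model exposes its early layers as conv1, conv2, ...
    for i in range(1, num_layers + 1):
        layer = getattr(model, "conv{}".format(i))
        for param in layer.parameters():
            param.requires_grad = False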
Code Example #4
def main():
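    """Train CNN and evaluate on the validation set with a confusion matrix."""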
    # Data loaders
    tr_loader, va_loader, te_loader, _ = get_train_val_test_loaders(
        num_classes=config('cnn.num_classes'))

    # Model
    model = CNN()

    # Define loss function and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

    print('Number of float-valued parameters:', count_parameters(model))

    # Attempt to restore the latest checkpoint, if one exists
    print('Loading cnn...')
    model, start_epoch, stats = restore_checkpoint(model,
                                                   config('cnn.checkpoint'))

    axes = utils.make_cnn_training_plot()

    # Evaluate the randomly initialized model
    _evaluate_epoch(axes, tr_loader, va_loader, model, criterion, start_epoch,
                    stats)

    # Loop over the entire dataset multiple times
    for epoch in range(start_epoch, config('cnn.num_epochs')):
        # Train model
        _train_epoch(tr_loader, model, criterion, optimizer)

        # Evaluate model
        _evaluate_epoch(axes, tr_loader, va_loader, model, criterion,
                        epoch + 1, stats)

        # Save model parameters
        save_checkpoint(model, epoch + 1, config('cnn.checkpoint'), stats)

    print('Finished Training')

    y_true, y_pred = [], []
    correct, total = 0, 0
    running_loss = []
    for X, y in va_loader:
        with torch.no_grad():
            output = model(X)
            predicted = predictions(output.data)
            y_true.extend(y)
            y_pred.extend(predicted)
            total += y.size(0)
            correct += (predicted == y).sum().item()
            running_loss.append(criterion(output, y).item())
    print("Validation data accuracies:")
    print(confusion_matrix(y_true, y_pred))

    # Save figure and keep plot open
    utils.save_cnn_training_plot()
    utils.hold_training_plot()
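Several of these examples call a predictions helper on the raw model outputs. A minimal sketch of what it plausibly does, inferred from usage here (the real helper may differ):

import torch

def predictions(logits):
    """Return the predicted class index for each row of raw scores (a sketch)."""
    return torch.argmax(logits, dim=1)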
Code Example #5
def main():
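    """Train CNN, then evaluate per-label validation performance."""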
    # Data loaders
    tr_loader, va_loader, te_loader, get_semantic_labels = get_train_val_test_loaders(
        num_classes=config('cnn.num_classes'))

    # Model
    model = CNN()

    # Define loss function and optimizer, collecting parameters layer by layer
    params = list(model.conv1.parameters()) + list(
        model.conv2.parameters()) + list(model.conv3.parameters())
    params = params + list(model.fc1.parameters()) + list(
        model.fc2.parameters()) + list(model.fc3.parameters())
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params, lr=0.0001)

    print('Number of float-valued parameters:', count_parameters(model))

    # Attempt to restore the latest checkpoint, if one exists
    print('Loading cnn...')
    model, start_epoch, stats = restore_checkpoint(model,
                                                   config('cnn.checkpoint'))

    fig, axes = utils.make_cnn_training_plot()

    # Evaluate the randomly initialized model
    _evaluate_epoch(axes, tr_loader, va_loader, model, criterion, start_epoch,
                    stats)

    # Loop over the entire dataset multiple times
    for epoch in range(start_epoch, config('cnn.num_epochs')):
        # Train model
        _train_epoch(tr_loader, model, criterion, optimizer)

        # Evaluate model
        _evaluate_epoch(axes, tr_loader, va_loader, model, criterion,
                        epoch + 1, stats)

        # Save model parameters
        save_checkpoint(model, epoch + 1, config('cnn.checkpoint'), stats)

    print('Finished Training')

    model, _, _ = restore_checkpoint(model, config('cnn.checkpoint'))

    dataset = get_data_by_label(va_loader)
    evaluate_cnn(dataset, model, criterion, get_semantic_labels)

    # Save figure and keep plot open
    utils.save_cnn_training_plot(fig)
    utils.hold_training_plot()
Code Example #6
def main():
    """Create confusion matrix and save to file."""
    tr_loader, va_loader, te_loader, semantic_labels = get_train_val_test_loaders(
        task="source", batch_size=config("source.batch_size"))

    model = Source()
    print("Loading source...")
    model, epoch, stats = restore_checkpoint(model,
                                             config("source.checkpoint"))

    sem_labels = "0 - Samoyed\n1 - Miniature Poodle\n2 - Saint Bernard\n3 - Great Dane\n4 - Dalmatian\n5 - Chihuahua\n6 - Siberian Husky\n7 - Yorkshire Terrier"

    # Evaluate model
    plot_conf(va_loader, model, sem_labels, "conf_matrix.png")
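plot_conf itself is not shown in this excerpt. A minimal sketch consistent with the call above, built on sklearn and matplotlib; the internals are assumptions:

import matplotlib.pyplot as plt
import torch
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

def plot_conf(loader, model, sem_labels, filename):
    """Plot a confusion matrix over a loader and save it to file (a sketch)."""
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for X, y in loader:
            scores = model(X)
            y_true.extend(y.tolist())
            y_pred.extend(scores.argmax(dim=1).tolist())
    disp = ConfusionMatrixDisplay(confusion_matrix(y_true, y_pred))
    disp.plot()
    # Attach the class legend built in main() as a small left-aligned title
    plt.title(sem_labels, fontsize=6, loc="left")
    plt.savefig(filename, bbox_inches="tight")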
Code Example #7
def main(uniqname):
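    """Write semantic-label predictions for the challenge test set to uniqname.csv."""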
    # data loaders
    _, _, te_loader, get_semantic_label = get_train_val_test_loaders(
        num_classes=config('challenge.num_classes'))

    model = Challenge()

    # Attempt to restore the latest checkpoint, if one exists
    model, _, _ = restore_checkpoint(model, config('challenge.checkpoint'))

    # Evaluate model
    model_pred = predict_challenge(te_loader, model)

    print('Saving challenge predictions...\n')
    model_pred = [get_semantic_label(p) for p in model_pred]
    pd_writer = pd.DataFrame(model_pred, columns=['predictions'])
    pd_writer.to_csv(uniqname + '.csv', index=False, header=False)
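predict_challenge is also not shown. A minimal sketch consistent with its use above; that the loader yields (X, y) pairs and that the model returns raw class scores are both assumptions:

import torch

def predict_challenge(data_loader, model):
    """Collect argmax class predictions over a loader (a sketch)."""
    model.eval()
    preds = []
    with torch.no_grad():
        for X, _ in data_loader:
            preds.extend(model(X).argmax(dim=1).tolist())
    return preds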
Code Example #8
def main():
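    """Visualize autoencoder reconstructions alongside a naive baseline."""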
    # data loaders
    _, va_loader, _, get_semantic_label = get_train_val_test_loaders(
        num_classes=config('autoencoder.num_classes'))
    dataset = get_data_by_label(va_loader)

    ae = Autoencoder(config('autoencoder.ae_repr_dim'))
    naive = NaiveRecon(config('autoencoder.naive_scale'))

    # Restore the latest checkpoint of autoencoder
    print('Loading autoencoder...')
    ae, _, _ = restore_checkpoint(ae,
                                  config('autoencoder.checkpoint'),
                                  force=True)

    # Visualize
    visualize_autoencoder(dataset, get_semantic_label, ae, naive)
Code Example #9
File: train_cnn.py, Project: linxiaow/CNN_Prototype
def main():
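    """Train CNN and display training plots."""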
    # Data loaders
    tr_loader, va_loader, te_loader, _ = get_train_val_test_loaders(
        num_classes=config('cnn.num_classes'))

    # Model
    model = CNN()

    # Define loss function and optimizer
    # (these imports would normally live at the top of the module)
    import torch.nn as nn
    import torch.optim as op
    criterion = nn.CrossEntropyLoss()
    optimizer = op.Adam(model.parameters(), lr=config('cnn.learning_rate'))

    print('Number of float-valued parameters:', count_parameters(model))

    # Attempt to restore the latest checkpoint, if one exists
    print('Loading cnn...')
    model, start_epoch, stats = restore_checkpoint(model,
        config('cnn.checkpoint'))

    axes = utils.make_cnn_training_plot()

    # Evaluate the randomly initialized model
    _evaluate_epoch(axes, tr_loader, va_loader, model, criterion, start_epoch,
        stats)
    
    # Loop over the entire dataset multiple times
    for epoch in range(start_epoch, config('cnn.num_epochs')):
        # Train model
        _train_epoch(tr_loader, model, criterion, optimizer)
        
        # Evaluate model
        _evaluate_epoch(axes, tr_loader, va_loader, model, criterion, epoch+1,
            stats)

        # Save model parameters
        save_checkpoint(model, epoch+1, config('cnn.checkpoint'), stats)

    print('Finished Training')

    # Save figure and keep plot open
    utils.save_cnn_training_plot()
    utils.hold_training_plot()
Code Example #10
def main():
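    """Train autoencoder and display training plots."""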
    # data loaders
    tr_loader, va_loader, te_loader, _ = get_train_val_test_loaders(
        num_classes=config('autoencoder.num_classes'))

    # Model
    model = Autoencoder(config('autoencoder.ae_repr_dim'))

    # Define loss function and optimizer, collecting parameters layer by layer
    criterion = torch.nn.MSELoss()
    params = list(model.pool.parameters()) + list(
        model.fc1.parameters()) + list(model.fc2.parameters())
    params = params + list(model.fc3.parameters()) + list(
        model.deconv.parameters())
    optimizer = torch.optim.Adam(params, lr=0.0001)

    # Attempt to restore the latest checkpoint, if one exists
    print('Loading autoencoder...')
    model, start_epoch, stats = restore_checkpoint(
        model, config('autoencoder.checkpoint'))

    fig, axes = utils.make_ae_training_plot()

    # Loop over the entire dataset multiple times
    for epoch in range(start_epoch, config('autoencoder.num_epochs')):
        # Evaluate model
        _evaluate_epoch(axes, tr_loader, va_loader, model, criterion,
                        epoch + 1, stats)

        # Train model (note: this also trains on the unlabeled test split,
        # which leaks test data into training)
        _train_epoch(tr_loader, model, criterion, optimizer)
        _train_epoch(te_loader, model, criterion, optimizer)

        # Save model parameters
        save_checkpoint(model, epoch + 1, config('autoencoder.checkpoint'),
                        stats)

    print('Finished Training')

    # Save figure and keep plot open
    utils.save_ae_training_plot(fig)
    utils.hold_training_plot()
Code Example #11
def main():
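    """Evaluate autoencoder reconstruction on validation data grouped by label."""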
    # data loaders
    _, va_loader, _, get_semantic_label = get_train_val_test_loaders(
        num_classes=config('autoencoder.num_classes'))
    dataset = get_data_by_label(va_loader)

    model = Autoencoder(config('autoencoder.ae_repr_dim'))
    criterion = torch.nn.MSELoss()

    # Attempt to restore the latest checkpoint, if one exists
    print('Loading autoencoder...')
    model, start_epoch, _ = restore_checkpoint(model,
                                               config('autoencoder.checkpoint'))

    # Evaluate model
    evaluate_autoencoder(dataset, get_semantic_label, model, criterion)
Code Example #12
def main():
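    """Train challenge model and display training plots."""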
    # data loaders
    tr_loader, va_loader, te_loader, _ = get_train_val_test_loaders(
        num_classes=config('challenge.num_classes'))

    # Define model, loss function, and optimizer
    model = Challenge_try()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

    # Attempt to restore the latest checkpoint, if one exists
    print('Loading challenge...')
    model, start_epoch, stats = restore_checkpoint(
        model, config('challenge.checkpoint'))

    fig, axes = utils.make_cnn_training_plot(name='Challenge')

    # Evaluate model
    _evaluate_epoch(axes, tr_loader, va_loader, model, criterion, start_epoch,
                    stats)

    # Loop over the entire dataset multiple times
    for epoch in range(start_epoch, config('challenge.num_epochs')):
        # Train model
        _train_epoch(tr_loader, model, criterion, optimizer)

        # Evaluate model
        _evaluate_epoch(axes, tr_loader, va_loader, model, criterion,
                        epoch + 1, stats)

        # Save model parameters
        save_checkpoint(model, epoch + 1, config('challenge.checkpoint'),
                        stats)

    print('Finished Training')

    # Save figure and keep plot open
    utils.save_cnn_training_plot(fig, name='challenge')
    utils.hold_training_plot()
Code Example #13
def main():
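    """Evaluate autoencoder and report validation performance."""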
    # data loaders
    _, va_loader, _, get_semantic_label = get_train_val_test_loaders(
        num_classes=config('autoencoder.num_classes'))
    dataset = get_data_by_label(va_loader)

    model = Autoencoder(config('autoencoder.ae_repr_dim'))

    criterion = torch.nn.MSELoss()

    # Attempt to restore the latest checkpoint, if one exists
    print('Loading autoencoder...')
    model, start_epoch, _ = restore_checkpoint(
        model, config('autoencoder.checkpoint'))

    # Evaluate model with the MSE criterion
    evaluate_autoencoder(dataset, get_semantic_label, model, criterion)

    # Swap in accuracy as the metric (pass the function itself, not a call)
    criterion = metrics.accuracy_score
    evaluate_autoencoder(dataset, get_semantic_label, model, criterion)

    # Report performance
    report_validation_performance(dataset, get_semantic_label, model,
                                  criterion)
Code Example #14
def main():
    """Train CNN and show training plots."""
    # Data loaders
    if check_for_augmented_data("./data"):
        tr_loader, va_loader, te_loader, _ = get_train_val_test_loaders(
            task="target", batch_size=config("cnn.batch_size"), augment=True)
    else:
        tr_loader, va_loader, te_loader, _ = get_train_val_test_loaders(
            task="target",
            batch_size=config("cnn.batch_size"),
        )
    # Model
    model = Target()

    # Define loss function and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    print("Number of float-valued parameters:", count_parameters(model))

    # Attempt to restore the latest checkpoint, if one exists
    print("Loading cnn...")
    model, start_epoch, stats = restore_checkpoint(model,
                                                   config("cnn.checkpoint"))

    axes = utils.make_training_plot()

    # Evaluate the randomly initialized model
    evaluate_epoch(axes, tr_loader, va_loader, te_loader, model, criterion,
                   start_epoch, stats)

    # initial val loss for early stopping
    prev_val_loss = stats[0][1]

    # Patience for early stopping
    patience = 5
    curr_patience = 0

    # Loop until validation loss stops improving (early stopping)
    epoch = start_epoch
    while curr_patience < patience:
        # Train model
        train_epoch(tr_loader, model, criterion, optimizer)

        # Evaluate model
        evaluate_epoch(axes, tr_loader, va_loader, te_loader, model, criterion,
                       epoch + 1, stats)

        # Save model parameters
        save_checkpoint(model, epoch + 1, config("cnn.checkpoint"), stats)

        # update early stopping parameters
        curr_patience, prev_val_loss = early_stopping(stats, curr_patience,
                                                      prev_val_loss)

        epoch += 1
    print("Finished Training")
    # Save figure and keep plot open
    utils.save_cnn_training_plot()
    utils.hold_training_plot()
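The early_stopping helper used here is not shown. A minimal sketch consistent with how it is called; the stats row layout (validation loss at index 1) is an assumption inferred from prev_val_loss = stats[0][1] above:

def early_stopping(stats, curr_patience, prev_val_loss):
    """Update the patience counter from the latest validation loss (a sketch)."""
    val_loss = stats[-1][1]  # assumed: validation loss is stored at index 1
    if val_loss >= prev_val_loss:
        curr_patience += 1  # no improvement this epoch
    else:
        curr_patience = 0  # improvement resets the counter
        prev_val_loss = val_loss
    return curr_patience, prev_val_loss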
Code Example #15
def main():
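    """Train challenge model with early stopping and display training plots."""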
    # Data loaders
    if check_for_augmented_data("./data"):
        tr_loader, va_loader, te_loader, _ = get_train_val_test_loaders(
            task="augment",
            batch_size=config("challenge.batch_size"),
        )
    else:
        tr_loader, va_loader, te_loader, _ = get_train_val_test_loaders(
            task="target",
            batch_size=config("challenge.batch_size"),
        )
    # Model
    model = Challenge()

    # Define loss function and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Attempt to restore the latest checkpoint, if one exists
    print("Loading challenge...")
    model, start_epoch, stats = restore_checkpoint(model, config("challenge.checkpoint"))

    axes = utils.make_cnn_training_plot()

    # Evaluate the randomly initialized model
    _evaluate_epoch(
        axes, tr_loader, va_loader, te_loader, model, criterion, start_epoch, stats
    )

    # initial val loss for early stopping
    prev_val_loss = stats[0][1]

    # Patience for early stopping
    patience = 5
    curr_patience = 0

    # Loop over the entire dataset multiple times
    epoch = start_epoch
    while curr_patience < patience:
        # Train model
        _train_epoch(tr_loader, model, criterion, optimizer)

        # Evaluate model
        _evaluate_epoch(
            axes, tr_loader, va_loader, te_loader, model, criterion, epoch + 1, stats
        )

        # Save model parameters
        save_checkpoint(model, epoch + 1, config("challenge.checkpoint"), stats)

        # Update early stopping parameters
        curr_patience, prev_val_loss = early_stopping(
            stats, curr_patience, prev_val_loss
        )
        epoch += 1
    print("Finished Training")
    # Save figure and keep plot open
    utils.save_challenge_training_plot()
    utils.hold_training_plot()
Code Example #16
    _ = gcam.forward(xi)
    gcam.backward(ids=torch.tensor([[target_class]]).to(device))
    regions = gcam.generate(target_layer=target_layer)
    activation = regions.detach()
    save_gradcam(
        np.squeeze(activation),
        utils.denormalize_image(np.squeeze(xi.numpy()).transpose(1, 2, 0)),
        axarr,
        i,
    )


if __name__ == "__main__":
    # Attempts to restore from checkpoint
    print("Loading cnn...")
    model = Source()
    model, start_epoch, _ = restore_checkpoint(model,
                                               config("net.checkpoint"),
                                               force=True)

    tr_loader, va_loader, te_loader, _ = get_train_val_test_loaders(
        task="target",
        batch_size=config("net.batch_size"),
    )
    for i in range(40):
        plt.clf()
        f, axarr = plt.subplots(1, 2)
        visualize_input(i, axarr)
        visualize_layer1_activations(i, axarr)
        plt.close()
Code Example #17
    zi = zi[sort_mask]
    fig, axes = plt.subplots(4, 4, figsize=(10,10))
    for i, ax in enumerate(axes.ravel()):
        ax.axis('off')
        im = ax.imshow(zi[i], cmap='gray')
    fig.suptitle('Layer 1 activations, y={}'.format(yi))
    fig.savefig('CNN_viz1_{}.png'.format(yi), dpi=200, bbox_inches='tight')

if __name__ == '__main__':
    # Attempts to restore from checkpoint
    print('Loading cnn...')
    model = CNN()
    model, start_epoch, _ = restore_checkpoint(model, config('cnn.checkpoint'),
        force=True)

    tr_loader, _, _, _ = get_train_val_test_loaders(
        num_classes=config('cnn.num_classes'))

    # Saving input images in original resolution
    metadata = pd.read_csv(config('csv_file'))
    for idx in [0, 4, 14, 15, 21]:
        filename = os.path.join(
            config('image_path'), metadata.loc[idx, 'filename'])
        plt.imshow(imread(filename))
        plt.axis('off')
        plt.savefig('CNN_viz0_{}.png'.format(int(
            metadata.loc[idx, 'numeric_label'])),
            dpi=200, bbox_inches='tight')

    # Saving layer activations
    for i in [0, 2, 5, 6, 9]:
        visualize_layer1_activations(i)
Code Example #18
def main():
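    """Train the source model with wandb logging and report the epoch with the lowest validation loss."""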
    filename = config("savefilename")
    lr = 0.0001

    this_config = dict(csv_file=config("csv_file"),
                       img_path=config("image_path"),
                       learning_rate=lr,
                       num_classes=4,
                       batchsize=64)

    wandb.init(project="prob_fix", name=filename, config=this_config)

    tr_loader, va_loader, te_loader, _ = get_train_val_test_loaders(
        task="default", batch_size=config("net.batch_size"))

    print('Successfully loaded data!')

    model = Source()
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    print("Number of float-valued parameters:", count_parameters(model))

    # Restore from the same checkpoint path used by save_checkpoint below
    model, start_epoch, stats = restore_checkpoint(model,
                                                   config("net.checkpoint"))

    axes = utils.make_training_plot()
    prolist = []

    evaluate_epoch(axes,
                   tr_loader,
                   va_loader,
                   te_loader,
                   model,
                   criterion,
                   start_epoch,
                   stats,
                   prolist,
                   multiclass=True)

    # initial val loss for early stopping
    prev_val_loss = stats[0][1]

    # Patience for early stopping
    patience = 5
    curr_patience = 0

    # Loop until validation loss stops improving (early stopping)
    epoch = start_epoch

    lowest_val_loss = float('inf')
    lowest_round = epoch
    while curr_patience < patience:
        # Train model
        train_epoch(tr_loader, model, criterion, optimizer)

        # Evaluate model
        evaluate_epoch(axes,
                       tr_loader,
                       va_loader,
                       te_loader,
                       model,
                       criterion,
                       epoch + 1,
                       stats,
                       prolist,
                       multiclass=True)

        # Save model parameters
        save_checkpoint(model, epoch + 1, config("net.checkpoint"), stats)

        # update early stopping parameters
        curr_patience, prev_val_loss = early_stopping(stats, curr_patience,
                                                      prev_val_loss)

        epoch += 1
        if prev_val_loss < lowest_val_loss:
            lowest_val_loss = prev_val_loss
            lowest_round = epoch

    pickle.dump(prolist, open("base_pro.pck", "wb"))
    print("Finished Training")
    # Save figure and keep plot open
    print("the lowest round: ", lowest_round)
    # utils.save_cnn_training_plot()
    # utils.save_cnn_other()
    utils.hold_training_plot()
Code Example #19
def main():
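    """Train the autoencoder classifier with a frozen encoder and report per-class accuracy."""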
    # data loaders
    tr_loader, va_loader, te_loader, _ = get_train_val_test_loaders(
        num_classes=config('autoencoder.classifier.num_classes'))

    ae_classifier = AutoencoderClassifier(config('autoencoder.ae_repr_dim'),
        config('autoencoder.classifier.num_classes'))
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(ae_classifier.parameters(),
        lr=config('autoencoder.classifier.learning_rate'))

    # freeze the weights of the encoder
    for name, param in ae_classifier.named_parameters():
        if 'fc1.' in name or 'fc2.' in name:
            param.requires_grad = False

    # Attempt to restore the latest checkpoint, if one exists
    print('Loading autoencoder...')
    ae_classifier, _, _ = restore_checkpoint(ae_classifier,
        config('autoencoder.checkpoint'), force=True, pretrain=True)
    print('Loading autoencoder classifier...')
    ae_classifier, start_epoch, stats = restore_checkpoint(ae_classifier,
        config('autoencoder.classifier.checkpoint'))

    fig, axes = utils.make_cnn_training_plot(name='Autoencoder Classifier')

    # Evaluate the randomly initialized model
    _evaluate_epoch(axes, tr_loader, va_loader, ae_classifier, criterion,
        start_epoch, stats)

    # Loop over the entire dataset multiple times
    for epoch in range(start_epoch, config('autoencoder.classifier.num_epochs')):
        # Train model
        _train_epoch(tr_loader, ae_classifier, criterion, optimizer)

        # Evaluate model
        _evaluate_epoch(axes, tr_loader, va_loader, ae_classifier, criterion,
            epoch+1, stats)

        # Per-class accuracy on validation data after the first epoch
        if epoch == start_epoch:
            r = [[], [], [], [], []]
            for X, y in va_loader:
                with torch.no_grad():
                    output = ae_classifier(X)
                    predict_res = predictions(output.data)
                    for y_sub, pred_out in zip(y, predict_res):
                        r[y_sub.item()].append((pred_out == y_sub).item())

            for i in range(5):
                print("Class", i, "gives accuracy", np.mean(r[i]))

        # Save model parameters
        save_checkpoint(ae_classifier, epoch+1,
            config('autoencoder.classifier.checkpoint'), stats)

    print('Finished Training')

    # Save figure and keep plot open
    utils.save_cnn_training_plot(fig, name='ae_clf')
    utils.hold_training_plot()
Code Example #20
def main():
    """Train source model on multiclass data."""
    # Data loaders
    tr_loader, va_loader, te_loader, _ = get_train_val_test_loaders(
        task="source",
        batch_size=config("source.batch_size"),
    )

    # Model
    model = Source()

    # Define loss function and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-3,
                                 weight_decay=0.01)

    print("Number of float-valued parameters:", count_parameters(model))

    # Attempt to restore the latest checkpoint, if one exists
    print("Loading source...")
    model, start_epoch, stats = restore_checkpoint(model,
                                                   config("source.checkpoint"))

    axes = utils.make_training_plot("Source Training")

    # Evaluate the randomly initialized model
    evaluate_epoch(
        axes,
        tr_loader,
        va_loader,
        te_loader,
        model,
        criterion,
        start_epoch,
        stats,
        multiclass=True,
    )

    # initial val loss for early stopping
    prev_val_loss = stats[0][1]

    # Patience for early stopping
    patience = 10
    curr_patience = 0

    # Loop over the entire dataset multiple times
    epoch = start_epoch
    while curr_patience < patience:
        # Train model
        train_epoch(tr_loader, model, criterion, optimizer)

        # Evaluate model
        evaluate_epoch(
            axes,
            tr_loader,
            va_loader,
            te_loader,
            model,
            criterion,
            epoch + 1,
            stats,
            multiclass=True,
        )

        # Save model parameters
        save_checkpoint(model, epoch + 1, config("source.checkpoint"), stats)

        curr_patience, prev_val_loss = early_stopping(stats, curr_patience,
                                                      prev_val_loss)
        epoch += 1

    # Save figure and keep plot open
    print("Finished Training")
    utils.save_source_training_plot()
    utils.hold_training_plot()