Example #1
def test_model(test):
    with tf.Session() as sess:

        saver.restore(sess, "./model_dir/model1/model.ckpt")
        print("Model restored.")

        predictions = []
        labels = []

        for idx, row in test.iterrows():

            if not row["content"]:
                continue

            feed_dict, label = get_feed_dict(row)

            predicted_logits = sess.run(logits, feed_dict=feed_dict)

            predicted = util.normalize_predictions(
                predicted_logits[0][0][1:-1])

            print(row["content"][:50], "\n", "Actual:", label, "\nPredicted:",
                  predicted, "\n")

            predictions.append(predicted)
            labels.append(label)

    util.print_summary(labels, predictions)
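Example #1 depends on module-level objects (saver, logits) and on two helpers that are not part of this listing, get_feed_dict and util.normalize_predictions. The sketch below shows one plausible shape for the normalization step only, assuming per-token logits of shape (seq_len, n_labels); it is an illustrative stand-in, not the original implementation, and get_feed_dict is left unspecified.

import numpy as np

def normalize_predictions(token_logits):
    # Hypothetical stand-in: reduce per-token logits (seq_len x n_labels)
    # to a list of predicted label indices.
    return [int(np.argmax(row)) for row in token_logits]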
Example #2
def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = datasets.mnist(permute_train=True)

  # Build the network
  init_fn, f = stax.serial(
      stax.Dense(2048),
      stax.Tanh,
      stax.Dense(10))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Linearize the network about its initial parameters.
  f_lin = linearize(f, params)

  # Create and initialize an optimizer for both f and f_lin.
  opt_init, opt_apply, get_params = optimizers.momentum(FLAGS.learning_rate,
                                                        0.9)
  opt_apply = jit(opt_apply)

  state = opt_init(params)
  state_lin = opt_init(params)

  # Create a cross-entropy loss function.
  loss = lambda fx, y_hat: -np.mean(stax.logsoftmax(fx) * y_hat)

  # Specialize the loss function to compute gradients for both linearized and
  # full networks.
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
  grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

  # Train the network.
  print('Training.')
  print('Epoch\tLoss\tLinearized Loss')
  print('------------------------------------------')

  epoch = 0
  steps_per_epoch = 50000 // FLAGS.batch_size

  for i, (x, y) in enumerate(datasets.minibatch(
      x_train, y_train, FLAGS.batch_size, FLAGS.train_epochs)):

    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x, y), state)

    params_lin = get_params(state_lin)
    state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

    if i % steps_per_epoch == 0:
      print('{}\t{:.4f}\t{:.4f}'.format(
          epoch, loss(f(params, x), y), loss(f_lin(params_lin, x), y)))
      epoch += 1

  # Print out summary data comparing the linear / nonlinear model.
  x, y = x_train[:10000], y_train[:10000]
  util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
  util.print_summary(
      'test', y_test, f(params, x_test), f_lin(params_lin, x_test), loss)
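Example #2 (and Example #3 below) ends with util.print_summary(name, y, fx, fx_lin, loss). That helper is external to this listing; the following is a minimal sketch of one plausible implementation, assuming one-hot labels and allowing fx_lin to be None (as in Example #9). The body is an assumption, not the project's actual code.

import numpy as np

def print_summary(name, y, fx, fx_lin, loss):
    # Hypothetical stand-in: report loss and accuracy for the full network
    # and, when provided, for its linearization.
    acc = lambda f: np.mean(np.argmax(f, axis=1) == np.argmax(y, axis=1))
    print('{}: loss={:.4f} acc={:.4f}'.format(name, loss(fx, y), acc(fx)))
    if fx_lin is not None:
        print('{} (linearized): loss={:.4f} acc={:.4f}'.format(
            name, loss(fx_lin, y), acc(fx_lin)))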
Example #3
def main(unused_argv):
  # Build data pipelines.
  print('Loading data.')
  x_train, y_train, x_test, y_test = \
      datasets.mnist(FLAGS.train_size, FLAGS.test_size)

  # Build the network
  init_fn, f, _ = stax.serial(
      stax.Dense(4096, 1., 0.),
      stax.Erf(),
      stax.Dense(10, 1., 0.))

  key = random.PRNGKey(0)
  _, params = init_fn(key, (-1, 784))

  # Create and initialize an optimizer.
  opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
  state = opt_init(params)

  # Create an mse loss function and a gradient function.
  loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
  grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

  # Create an MSE predictor to solve the NTK equation in function space.
  ntk = batch(get_ntk_fun_empirical(f), batch_size=32, device_count=1)
  g_dd = ntk(x_train, None, params)
  g_td = ntk(x_test, x_train, params)
  predictor = predict.analytic_mse(g_dd, y_train, g_td)

  # Get initial values of the network in function space.
  fx_train = f(params, x_train)
  fx_test = f(params, x_test)

  # Train the network.
  train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
  print('Training for {} steps'.format(train_steps))

  for i in range(train_steps):
    params = get_params(state)
    state = opt_apply(i, grad_loss(params, x_train, y_train), state)

  # Get predictions from analytic computation.
  print('Computing analytic prediction.')
  fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

  # Print out summary data comparing the linear / nonlinear model.
  util.print_summary('train', y_train, f(params, x_train), fx_train, loss)
  util.print_summary('test', y_test, f(params, x_test), fx_test, loss)
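Example #3 uses an older neural-tangents API (get_ntk_fun_empirical, predict.analytic_mse). Assuming a recent neural-tangents release is installed, the equivalent empirical-kernel construction and function-space prediction would look roughly like this; the exact names are version-dependent, so treat the sketch as an approximation rather than a drop-in replacement.

import neural_tangents as nt

# Empirical NTK of the network f, evaluated in batches.
ntk_fn = nt.batch(nt.empirical_ntk_fn(f), batch_size=32, device_count=1)
g_dd = ntk_fn(x_train, None, params)
g_td = ntk_fn(x_test, x_train, params)

# Closed-form gradient-descent-MSE predictor in function space.
predict_fn = nt.predict.gradient_descent_mse(g_dd, y_train)
fx_train_t, fx_test_t = predict_fn(FLAGS.train_time, fx_train, fx_test, g_td)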
Example #4
def mm_pair_valid(device, batch_size, a_m, i_m, cls, testloader, epoch,
                  writer):
    a_m.eval()
    i_m.eval()
    cls.eval()
    criterion = nn.CrossEntropyLoss(reduction='mean')

    metric_ftns = ['loss', 'correct', 'nums', 'accuracy']
    val_metrics = MetricTracker(*[m for m in metric_ftns],
                                writer=writer,
                                mode='val')
    val_metrics.reset()
    confusion_matrix = torch.zeros(2, 2)
    with torch.no_grad():
        for batch_idx, (audio, img, label) in enumerate(testloader):

            audio, img = audio.to(device), img.to(device)
            label = label.to(device)
            i_output, i_feature = i_m(img)
            a_output, a_feature = a_m(audio)
            concat_output = cls(torch.cat([i_feature, a_feature], 1))

            loss = criterion(concat_output, label)

            correct, nums, acc = accuracy(concat_output, label)
            num_samples = batch_idx * batch_size + 1
            _, preds = torch.max(concat_output, 1)
            for t, p in zip(label.cpu().view(-1), preds.cpu().view(-1)):
                confusion_matrix[t.long(), p.long()] += 1
            val_metrics.update_all_metrics(
                {
                    'correct': correct,
                    'nums': nums,
                    'loss': loss.item(),
                    'accuracy': acc
                },
                writer_step=(epoch - 1) * len(testloader) + batch_idx)

    num_samples += len(label) - 1
    print_summary(epoch, num_samples, val_metrics, mode="Validation")

    print('Confusion Matrix\n{}'.format(confusion_matrix.cpu().numpy()))
    return val_metrics, confusion_matrix
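The PyTorch examples (#4, #5, #6 and #10) all call an accuracy(output, target) helper that returns a (correct, nums, acc) triple but is not shown in the listing. A minimal sketch of one consistent implementation, assuming raw logits and integer class labels; the percentage scaling is an assumption.

import torch

def accuracy(output, target):
    # Hypothetical stand-in: number of correct predictions, batch size,
    # and batch accuracy in percent.
    _, preds = torch.max(output, 1)
    correct = (preds == target).sum().item()
    nums = target.size(0)
    return correct, nums, 100.0 * correct / nums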
Example #5
def train(device, batch_size, model, trainloader, optimizer, epoch, writer):
    model.train()
    criterion = nn.CrossEntropyLoss(reduction='mean')

    metric_ftns = ['loss', 'correct', 'nums', 'accuracy']
    train_metrics = MetricTracker(*[m for m in metric_ftns],
                                  writer=writer,
                                  mode='train')
    train_metrics.reset()
    confusion_matrix = torch.zeros(2, 2)

    for batch_idx, input_tensors in enumerate(trainloader):
        optimizer.zero_grad()
        input_data, target = input_tensors
        input_data = input_data.to(device)
        target = target.to(device)

        output = model(input_data)

        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        correct, nums, acc = accuracy(output, target)
        num_samples = batch_idx * batch_size + 1
        _, preds = torch.max(output, 1)
        for t, p in zip(target.cpu().view(-1), preds.cpu().view(-1)):
            confusion_matrix[t.long(), p.long()] += 1
        train_metrics.update_all_metrics(
            {
                'correct': correct,
                'nums': nums,
                'loss': loss.item(),
                'accuracy': acc
            },
            writer_step=(epoch - 1) * len(trainloader) + batch_idx)
        print_stats(epoch, batch_size, num_samples, trainloader, train_metrics)
    num_samples += len(target) - 1

    print_summary(epoch, num_samples, train_metrics, mode="Training")
    print('Confusion Matrix\n{}\n'.format(confusion_matrix.cpu().numpy()))
    return train_metrics
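print_stats and print_summary in Examples #5 and #6 come from the same project as MetricTracker and are not included here. A rough sketch of what the epoch-level summary could look like, under the assumption that the tracker exposes an avg(key) accessor; that accessor, like the message format, is an assumption rather than the project's API.

def print_summary(epoch, num_samples, metrics, mode=""):
    # Hypothetical stand-in: report averaged epoch-level metrics.
    print('{} epoch {} ({} samples): loss={:.4f} accuracy={:.2f}%'.format(
        mode, epoch, num_samples, metrics.avg('loss'), metrics.avg('accuracy')))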
Example #6
def validation(device, batch_size, classes, model, testloader, epoch, writer):
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction='mean')

    metric_ftns = ['loss', 'correct', 'nums', 'accuracy']
    val_metrics = MetricTracker(*[m for m in metric_ftns],
                                writer=writer,
                                mode='val')
    val_metrics.reset()
    confusion_matrix = torch.zeros(classes, classes)
    with torch.no_grad():
        for batch_idx, input_tensors in enumerate(testloader):

            input_data, target = input_tensors
            input_data = input_data.to(device)
            target = target.to(device)

            output, _ = model(input_data)

            loss = criterion(output, target)

            correct, nums, acc = accuracy(output, target)
            num_samples = batch_idx * batch_size + 1
            _, preds = torch.max(output, 1)
            for t, p in zip(target.cpu().view(-1), preds.cpu().view(-1)):
                confusion_matrix[t.long(), p.long()] += 1
            val_metrics.update_all_metrics(
                {
                    'correct': correct,
                    'nums': nums,
                    'loss': loss.item(),
                    'accuracy': acc
                },
                writer_step=(epoch - 1) * len(testloader) + batch_idx)

    num_samples += len(target) - 1
    print_summary(epoch, num_samples, val_metrics, mode="Validation")

    print('Confusion Matrix\n{}'.format(confusion_matrix.cpu().numpy()))
    return val_metrics, confusion_matrix
Example #7
import sys
import pandas as pd
from util import get_summary_for_column, print_summary

df = pd.read_csv('./distro_output.csv')
# Common things to search... "Store", "Artist"

if len(sys.argv) == 2:
    column_name = sys.argv[1]
    column_titles = list(df.columns)
    # These column names are used in the other program already for calculating
    # the summary statistics.
    column_titles.remove("Earnings (USD)")
    column_titles.remove("Quantity")
    if column_name not in column_titles:
        sys.exit(
            f"Column name argument not valid! Column names are {column_titles}"
        )
else:
    sys.exit(
        "Need exactly one argument after filename! like\npython summarize.py Artist"
    )

# Valid DataFrame, valid args, continue...
summary_dict = get_summary_for_column(column_name=column_name, df=df)
print_summary(summary_dict)
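Example #7 only validates the command-line argument; the aggregation itself lives in get_summary_for_column. A plausible sketch, assuming the helper groups the DataFrame by the chosen column and totals the "Earnings (USD)" and "Quantity" columns (those column names appear in the script; the grouping logic is an assumption).

import pandas as pd

def get_summary_for_column(column_name, df):
    # Hypothetical stand-in: per-group totals keyed by the column's values.
    grouped = df.groupby(column_name)[["Earnings (USD)", "Quantity"]].sum()
    return grouped.to_dict(orient="index")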
Example #8
def main(unused_argv):
    # print(f'Available GPU memory: {util.get_gpu_memory()}')
    # Load and normalize data
    print('Loading data...')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                            n_train=60000,
                                                            n_test=10000,
                                                            permute_train=True)
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # Reformat MNIST data to 28x28x1 pictures
    x_train = np.asarray(x_train.reshape(-1, 28, 28, 1))
    x_test = np.asarray(x_test.reshape(-1, 28, 28, 1))
    print('Data loaded and reshaped')
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # Set random seed
    key = random.PRNGKey(0)

    # # Add random translation to images
    # x_train = util.add_translation(x_train, FLAGS.max_pixel)
    # x_test = util.add_translation(x_test, FLAGS.max_pixel)
    # print(f'Random translation by up to {FLAGS.max_pixel} pixels added')

    # # Add random translations with padding
    # x_train = util.add_padded_translation(x_train, 10)
    # x_test = util.add_padded_translation(x_test, 10)
    # print(f'Random translations with additional padding up to 10 pixels added')

    # Build the LeNet network with NTK parameterization
    init_fn, f, kernel_fn = util.build_le_net(FLAGS.network_width)
    print(f'Network of width x{FLAGS.network_width} built.')

    # # Construct the kernel function
    # kernel_fn = nt.batch(kernel_fn, device_count=-1, batch_size=FLAGS.batch_size_kernel)
    # print('Kernel constructed')
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # Compute random initial parameters
    _, params = init_fn(key, (-1, 28, 28, 1))
    params_lin = params

    print('Initial parameters constructed')
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # # Save initial parameters
    # with open('init_params.npy', 'wb') as file:
    #     np.save(file, params)

    # Linearize the network about its initial parameters.
    # Use jit for faster GPU computation (only enabled here for width <= 10)
    f_lin = nt.linearize(f, params)
    if FLAGS.network_width <= 10:
        f_jit = jit(f)
        f_lin_jit = jit(f_lin)
    else:
        f_jit = f
        f_lin_jit = f_lin

    # Create a callable function for dynamic learning rates
    # Starts with learning_rate, divided by 10 after learning_decline epochs.
    dynamic_learning_rate = lambda iteration_step: FLAGS.learning_rate / 10**(
        (iteration_step //
         (x_train.shape[0] // FLAGS.batch_size)) // FLAGS.learning_decline)

    # Create and initialize an optimizer for both f and f_lin.
    # Use momentum with coefficient 0.9 and jit
    opt_init, opt_apply, get_params = optimizers.momentum(
        dynamic_learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    # Compute the initial states
    state = opt_init(params)
    state_lin = opt_init(params)

    # Define the accuracy function
    accuracy = lambda fx, y_hat: np.mean(
        np.argmax(fx, axis=1) == np.argmax(y_hat, axis=1))

    # Define mean square error loss function
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)

    # # Create a cross-entropy loss function.
    # loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print(
        f'Training with dynamic learning decline after {FLAGS.learning_decline} epochs...'
    )
    print(
        'Epoch\tTime\tAccuracy\tLin. Accuracy\tLoss\tLin. Loss\tAccuracy Train\tLin.Accuracy Train'
    )
    print(
        '----------------------------------------------------------------------------------------------------------'
    )

    # Initialize training
    epoch = 0
    steps_per_epoch = x_train.shape[0] // FLAGS.batch_size

    # Set start time (total and 100 epochs)
    start = time.time()
    start_epoch = time.time()

    for i, (x, y) in enumerate(
            datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                               FLAGS.train_epochs)):

        # Update the parameters
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        # Print information after each 100 epochs
        if (i + 1) % (steps_per_epoch * 100) == 0:
            time_point = time.time() - start_epoch

            # Update epoch
            epoch += 100

            # Accuracy in batches
            f_x = util.output_in_batches(x_train, params, f_jit,
                                         FLAGS.batch_count_accuracy)
            f_x_test = util.output_in_batches(x_test, params, f_jit,
                                              FLAGS.batch_count_accuracy)
            f_x_lin = util.output_in_batches(x_train, params_lin, f_lin_jit,
                                             FLAGS.batch_count_accuracy)
            f_x_lin_test = util.output_in_batches(x_test, params_lin,
                                                  f_lin_jit,
                                                  FLAGS.batch_count_accuracy)
            # time_point = time.time() - start_epoch

            # Print information about past 100 epochs
            print(
                '{}\t{:.3f}\t{:.4f}\t\t{:.4f}\t\t{:.4f}\t{:.4f}\t\t{:.4f}\t\t{:.4f}'
                .format(epoch, time_point,
                        accuracy(f_x, y_train) * 100,
                        accuracy(f_x_lin, y_train) * 100, loss(f_x, y_train),
                        loss(f_x_lin, y_train),
                        accuracy(f_x_test, y_test) * 100,
                        accuracy(f_x_lin_test, y_test) * 100))

            # # Save params if epoch is multiple of learning decline or multiple of fixed value
            # if epoch % FLAGS.learning_decline == 0:
            #     filename = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_pmod_{epoch}_{FLAGS.learning_decline}.npy'
            #     with open(filename, 'wb') as file:
            #         np.save(file, params)
            #     filename_lin = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_pmod_{epoch}_{FLAGS.learning_decline}_lin.npy'
            #     with open(filename_lin, 'wb') as file_lin:
            #         np.save(file_lin, params_lin)

            # Reset timer
            start_epoch = time.time()

    duration = time.time() - start
    print(
        '----------------------------------------------------------------------------------------------------------'
    )
    print(f'Training complete in {duration} seconds.')

    # # Save final params in file
    # filename_final = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_final_pmod_{FLAGS.train_epochs}_{FLAGS.learning_decline}.npy '
    # with open(filename_final, 'wb') as final:
    #     np.save(final, params)
    # filename_final_lin = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_final_pmod_{FLAGS.train_epochs}_{FLAGS.learning_decline}_lin.npy'
    # with open(filename_final_lin, 'wb') as final_lin:
    #     np.save(final_lin, params_lin)

    # Compute output in batches
    f_x = util.output_in_batches(x_train, params, f_jit,
                                 FLAGS.batch_count_accuracy)
    f_x_lin = util.output_in_batches(x_train, params_lin, f_lin_jit,
                                     FLAGS.batch_count_accuracy)

    f_x_test = util.output_in_batches(x_test, params, f_jit,
                                      FLAGS.batch_count_accuracy)
    f_x_lin_test = util.output_in_batches(x_test, params_lin, f_lin_jit,
                                          FLAGS.batch_count_accuracy)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, f_x, f_x_lin, loss)
    util.print_summary('test', y_test, f_x_test, f_x_lin_test, loss)
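util.output_in_batches(x, params, apply_fn, batch_count) is used throughout Example #8 to keep the forward pass within memory limits. A minimal sketch, assuming it simply splits the input into batch_count chunks, applies the network to each chunk, and concatenates the outputs; the chunking strategy is an assumption.

import jax.numpy as np  # these examples alias jax.numpy as np

def output_in_batches(x, params, apply_fn, batch_count):
    # Hypothetical stand-in: evaluate apply_fn chunk by chunk to bound memory use.
    chunks = np.array_split(x, batch_count)
    return np.concatenate([apply_fn(params, chunk) for chunk in chunks])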
Example #9
def main(unused_argv):
    # Load and normalize data
    print('Loading data...')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist', n_train=10, n_test=10,
                                                            permute_train=True)

    # Reformat MNIST data to 28x28x1 pictures
    x_train = np.asarray(x_train.reshape(-1, 28, 28, 1))
    x_test = np.asarray(x_test.reshape(-1, 28, 28, 1))
    print(f'Data loaded and reshaped with n_train = {x_train.shape[0]} (batch size {FLAGS.batch_size_kernel}) and '
          f'n_test = {x_test.shape[0]}.')

    # # Add random translation to images
    # x_train = util.add_translation(x_train, FLAGS.max_pixel)
    # x_test = util.add_translation(x_test, FLAGS.max_pixel)
    # print(f'Random translations by up to {FLAGS.max_pixel} pixels added')

    # # Add random translations with padding
    # x_train = util.add_padded_translation(x_train, 10)
    # x_test = util.add_padded_translation(x_test, 10)
    # print(f'Random translations with additional padding up to 10 pixels added')


    # Build the LeNet network
    init_fn, f, kernel_fn = util.build_le_net(FLAGS.network_width)
    print('Network build complete')

    # Construct the kernel function
    # Use 'store_on_device = False' for larger kernels
    kernel_fn = nt.batch(kernel_fn, device_count=-1, batch_size=FLAGS.batch_size_kernel, store_on_device=False)

    # Set start time
    start_inf = time.time()

    # Bayesian and infinite-time gradient descent inference with infinite network
    print('Starting Bayesian and infinite-time gradient descent inference with the infinite network')
    predict_fn = nt.predict.gradient_descent_mse_ensemble(
        kernel_fn=kernel_fn,
        x_train=x_train,
        y_train=y_train,
        diag_reg=1e-6
    )

    duration_kernel = time.time() - start_inf
    print(f'Kernel constructed in {duration_kernel} seconds.')

    # fx_test_nngp_ub, fx_test_ntk_ub = predict_fn(x_test=x_test, get=('nngp', 'ntk'))

    # Pre-allocate the output lists ([] * n is always empty, so use [None] * n).
    fx_test_nngp = [None] * x_test.shape[0]
    fx_test_ntk = [None] * x_test.shape[0]
    print('Output vector allocated.')
    # print(f'Available GPU memory: {util.get_gpu_memory()} MiB')

    # Compute predictions in batches
    for i in range(x_test.shape[0] // FLAGS.batch_size_output):
        time_batch = time.time()
        start, end = i * FLAGS.batch_size_output, (i+1) * FLAGS.batch_size_output
        x = x_test[start:end]
        tmp_nngp, tmp_ntk = predict_fn(x_test=x, get=('nngp', 'ntk'))
        # tmp_ntk = predict_fn(x_test=x, get='ntk')
        duration_batch = time.time() - time_batch
        print(f'Batch {i+1} predicted in {duration_batch} seconds.')
        # print(f'Available GPU memory: {util.get_gpu_memory()} MiB')
        fx_test_nngp[start:end] = tmp_nngp
        fx_test_ntk[start:end] = tmp_ntk

    fx_test_nngp = np.array(fx_test_nngp)
    fx_test_ntk = np.array(fx_test_ntk)

    # fx_test_nngp.block_until_ready()
    # fx_test_ntk.block_until_ready()

    duration_inf = time.time() - start_inf

    print(f'Inference done in {duration_inf} seconds.')

    # Print out accuracy and loss for infinite network predictions.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
    util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
    util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)
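Slice-assigning into a list that was never truly pre-allocated is what made the original allocation line fragile. An alternative sketch, reusing predict_fn, FLAGS.batch_size_output and x_test from Example #9, that appends per-batch outputs and stacks them once at the end:

nngp_batches, ntk_batches = [], []
for i in range(x_test.shape[0] // FLAGS.batch_size_output):
    start, end = i * FLAGS.batch_size_output, (i + 1) * FLAGS.batch_size_output
    nngp_b, ntk_b = predict_fn(x_test=x_test[start:end], get=('nngp', 'ntk'))
    nngp_batches.append(nngp_b)
    ntk_batches.append(ntk_b)
fx_test_nngp = np.concatenate(nngp_batches)
fx_test_ntk = np.concatenate(ntk_batches)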
Example #10
def mm_train(device, batch_size, a_m, i_m, trainloader, optimizer, epoch,
             writer):
    a_m.train()
    i_m.train()
    weight = torch.tensor([0.1, 0.9]).to(device)
    ce_loss = nn.CrossEntropyLoss(weight=weight, reduction='mean')
    alpha = 1

    metric_ftns = ['loss', 'correct', 'nums', 'accuracy']
    image_metrics = MetricTracker(*[m for m in metric_ftns],
                                  writer=writer,
                                  mode='train')
    audio_metrics = MetricTracker(*[m for m in metric_ftns],
                                  writer=writer,
                                  mode='train')
    image_metrics.reset()
    audio_metrics.reset()
    i_confusion_matrix = torch.zeros(2, 2)
    a_confusion_matrix = torch.zeros(2, 2)

    for batch_idx, (audio, a_label, img, i_label) in enumerate(trainloader):

        audio, img = audio.to(device), img.to(device)
        a_label, i_label = a_label.to(device), i_label.to(device)
        i_output, i_feature = i_m(img)
        a_output, a_feature = a_m(audio)

        i_ce = ce_loss(i_output, i_label)
        a_ce = ce_loss(a_output, a_label)
        csa = csa_loss(a_feature, i_feature.detach(),
                       (a_label == i_label).float())

        loss = i_ce + 0.4 * a_ce + alpha * csa
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        num_samples = batch_idx * batch_size + 1
        # image
        i_correct, i_nums, i_acc = accuracy(i_output, i_label)
        image_metrics.update_all_metrics(
            {
                'correct': i_correct,
                'nums': i_nums,
                'loss': loss.item(),
                'accuracy': i_acc
            },
            writer_step=(epoch - 1) * len(trainloader) + batch_idx)
        # audio
        a_correct, a_nums, a_acc = accuracy(a_output, a_label)
        audio_metrics.update_all_metrics(
            {
                'correct': a_correct,
                'nums': a_nums,
                'loss': loss.item(),
                'accuracy': a_acc
            },
            writer_step=(epoch - 1) * len(trainloader) + batch_idx)
        _, preds = torch.max(a_output, 1)
        for t, p in zip(a_label.cpu().view(-1), preds.cpu().view(-1)):
            a_confusion_matrix[t.long(), p.long()] += 1
        print_stats(epoch,
                    batch_size,
                    num_samples,
                    trainloader,
                    image_metrics,
                    mode="Image",
                    acc=i_acc)
        print_stats(epoch,
                    batch_size,
                    num_samples,
                    trainloader,
                    audio_metrics,
                    mode="Audio",
                    acc=a_acc)
    num_samples += len(a_output) - 1

    print_summary(epoch, num_samples, image_metrics, mode="Training Image")
    print_summary(epoch, num_samples, audio_metrics, mode="Training Audio")
    print('A_Confusion Matrix\n{}\n'.format(a_confusion_matrix.cpu().numpy()))

    return audio_metrics
Example #11
def train(args, model, trainloader, optimizer, epoch):
    model.train()
    criterion = nn.CrossEntropyLoss(reduction='mean')  # 'elementwise_mean' was removed in newer PyTorch

    metrics = Metrics('')
    metrics.reset()
    batch_idx = 0
    for input_tensors in tqdm(trainloader):
        batch_idx = batch_idx + 1
        optimizer.zero_grad()
        input_data, target, site = input_tensors
        if args.cuda:
            input_data = input_data.cuda()
            target = target.cuda()

        ucsd_input = torch.from_numpy(np.array([]))
        new_input = torch.from_numpy(np.array([]))
        ucsd_label = torch.from_numpy(np.array([]))
        new_label = torch.from_numpy(np.array([]))

        for i in range(len(input_data)):
            if site[i] == 'ucsd':
                if len(ucsd_input) == 0:
                    ucsd_input = input_data[i].unsqueeze(0)
                    ucsd_label = torch.from_numpy(np.array([target[i]]))
                else:
                    ucsd_input = torch.cat((ucsd_input, input_data[i].unsqueeze(0)))
                    ucsd_label = torch.cat((ucsd_label, torch.from_numpy(np.array([target[i]]))))
            else:
                if len(new_input) == 0:
                    new_input = input_data[i].unsqueeze(0)
                    new_label = torch.from_numpy(np.array([target[i]]))
                else:
                    new_input = torch.cat((new_input, input_data[i].unsqueeze(0)))
                    new_label = torch.cat((new_label, torch.from_numpy(np.array([target[i]]))))

        if len(ucsd_input) > 1:
            ucsd_output, ucsd_features = model(ucsd_input, 'ucsd')
        if len(new_input) > 1:
            new_output, new_features = model(new_input, 'ucsd')

        if len(ucsd_input) > 1 and len(new_input) > 1:
            output = torch.cat((ucsd_output, new_output))
            labels = torch.cat((ucsd_label, new_label)).cuda()
            features = torch.cat((ucsd_features, new_features))
        elif len(ucsd_input) > 1 and len(new_input) < 2:
            output = ucsd_output
            labels = ucsd_label.cuda()
            features = ucsd_features
        else:
            output = new_output
            labels = new_label.cuda()
            features = new_features

        if len(output) != len(labels):
            continue

        if len(features) == 32 and args.cont:
            temperature = 0.05
            cont_loss_func = losses.NTXentLoss(temperature)
            cont_loss = cont_loss_func(features, labels)
            loss = criterion(output, labels) + cont_loss
        else:
            loss = criterion(output, labels)

        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        optimizer.step()
        correct, total, acc = accuracy(output, labels)
        top1_correct = top_k_acc(output, labels, k=1)
        top3_correct = top_k_acc(output, labels, k=2)  # computes top-2 accuracy despite the name

        metrics.update({'correct': correct, 'total': total, 'loss': loss.item(), 'accuracy': acc,
                        'top1_correct': top1_correct, 'top3_correct': top3_correct})

    print_summary(args, epoch, metrics, mode="train")
    return metrics
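top_k_acc(output, labels, k) in Examples #11 and #12 is also external to this listing. A minimal sketch, assuming it returns the number of samples whose true label appears among the k highest-scoring classes; the return type (a plain count) is an assumption.

import torch

def top_k_acc(output, labels, k=1):
    # Hypothetical stand-in: count labels found in the k highest-scoring classes.
    _, topk = output.topk(k, dim=1)
    return topk.eq(labels.view(-1, 1)).any(dim=1).sum().item()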
Example #12
def validation(args, model, testloader, epoch, mode='test'):
    conf_matrix = torch.zeros(args.classes, args.classes).cuda()
    model.eval()
    criterion = nn.CrossEntropyLoss()

    metrics = Metrics('')
    metrics.reset()
    batch_idx = 0
    ucsd_correct_total = 0
    sars_correct_total = 0
    ucsd_test_total = 0
    sars_test_total = 0
    with torch.no_grad():
        for input_tensors in tqdm(testloader):
            batch_idx = batch_idx + 1
            input_data, target, site = input_tensors
            if args.cuda:
                input_data = input_data.cuda()
                target = target.cuda()

            ucsd_input = torch.from_numpy(np.array([]))
            new_input = torch.from_numpy(np.array([]))
            ucsd_label = torch.from_numpy(np.array([]))
            new_label = torch.from_numpy(np.array([]))

            for i in range(len(input_data)):
                if site[i] == 'ucsd':
                    if len(ucsd_input) == 0:
                        ucsd_input = input_data[i].unsqueeze(0)
                        ucsd_label = torch.from_numpy(np.array([target[i]]))
                    else:
                        ucsd_input = torch.cat((ucsd_input, input_data[i].unsqueeze(0)))
                        ucsd_label = torch.cat((ucsd_label, torch.from_numpy(np.array([target[i]]))))
                else:
                    if len(new_input) == 0:
                        new_input = input_data[i].unsqueeze(0)
                        new_label = torch.from_numpy(np.array([target[i]]))
                    else:
                        new_input = torch.cat((new_input, input_data[i].unsqueeze(0)))
                        new_label = torch.cat((new_label, torch.from_numpy(np.array([target[i]]))))

            if len(ucsd_input) > 1:
                ucsd_output, ucsd_features = model(ucsd_input, 'ucsd')
                ucsd_correct, ucsd_total, ucsd_acc = accuracy(ucsd_output, ucsd_label.cuda())
                ucsd_correct_total += ucsd_correct
                ucsd_test_total += ucsd_total

            if len(new_input) > 1:
                new_output, new_features = model(new_input, 'ucsd')
                sars_correct, sars_total, sars_acc = accuracy(new_output, new_label.cuda())
                sars_correct_total += sars_correct
                sars_test_total += sars_total

            if len(ucsd_input) > 1 and len(new_input) > 1:
                output = torch.cat((ucsd_output, new_output))
                labels = torch.cat((ucsd_label, new_label)).cuda()
                features = torch.cat((ucsd_features, new_features))
            elif len(ucsd_input) > 1 and len(new_input) < 2:
                output = ucsd_output
                labels = ucsd_label.cuda()
            else:
                output = new_output
                labels = new_label.cuda()

            loss = criterion(output, labels)

            preds = torch.argmax(output, dim=1)
            for t, p in zip(target.view(-1), preds.view(-1)):
                conf_matrix[t.long(), p.long()] += 1

            correct, total, acc = accuracy(output, labels)

            # top k acc
            top1_correct = top_k_acc(output, labels, k=1)
            top3_correct = top_k_acc(output, labels, k=2)  # computes top-2 accuracy despite the name

            metrics.update({'correct': correct, 'total': total, 'loss': loss.item(), 'accuracy': acc,
                            'top1_correct': top1_correct, 'top3_correct': top3_correct})

    print_summary(args, epoch, metrics, mode="test")

    return metrics, conf_matrix, ucsd_correct_total, sars_correct_total, ucsd_test_total, sars_test_total
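Examples #11 and #12 record results with a Metrics('') object that only needs reset() and update(dict) here. A bare-bones sketch of such a tracker, assuming it simply accumulates running sums per key; this is a hypothetical stand-in, not the project's class.

class Metrics:
    # Hypothetical stand-in for the tracker used in Examples #11 and #12.
    def __init__(self, name=''):
        self.name = name
        self.data = {}

    def reset(self):
        self.data = {}

    def update(self, values):
        # Accumulate each reported value under its key.
        for key, value in values.items():
            self.data[key] = self.data.get(key, 0) + value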