Example 1
def test_one_shot(args, model, test_samples=5000, partition='test'):
    io = io_utils.IOStream('checkpoints/' + args.exp_name + '/run.log')

    io.cprint('\n**** TESTING WITH %s ****' % (partition,))

    loader = generator.Generator(args.dataset_root, args, partition=partition, dataset=args.dataset)

    [enc_nn, metric_nn, softmax_module] = model
    enc_nn.eval()
    metric_nn.eval()
    correct = 0
    total = 0
    iterations = int(test_samples/args.batch_size_test)
    for i in range(iterations):
        data = loader.get_task_batch(batch_size=args.batch_size_test, n_way=args.test_N_way,
                                     num_shots=args.test_N_shots, unlabeled_extra=args.unlabeled_extra)
        [x, labels_x_cpu, _, _, xi_s, labels_yi_cpu, oracles_yi, hidden_labels] = data

        if args.cuda:
            xi_s = [batch_xi.cuda() for batch_xi in xi_s]
            labels_yi = [label_yi.cuda() for label_yi in labels_yi_cpu]
            oracles_yi = [oracle_yi.cuda() for oracle_yi in oracles_yi]
            hidden_labels = hidden_labels.cuda()
            x = x.cuda()
        else:
            labels_yi = labels_yi_cpu

        xi_s = [Variable(batch_xi) for batch_xi in xi_s]
        labels_yi = [Variable(label_yi) for label_yi in labels_yi]
        oracles_yi = [Variable(oracle_yi) for oracle_yi in oracles_yi]
        hidden_labels = Variable(hidden_labels)
        x = Variable(x)

        # Compute embedding from x and xi_s
        z = enc_nn(x)[-1]
        zi_s = [enc_nn(batch_xi)[-1] for batch_xi in xi_s]

        # Compute metric from embeddings
        output, out_logits = metric_nn(inputs=[z, zi_s, labels_yi, oracles_yi, hidden_labels])
        output = out_logits
        y_pred = softmax_module.forward(output)
        y_pred = y_pred.data.cpu().numpy()
        y_pred = np.argmax(y_pred, axis=1)
        labels_x_cpu = labels_x_cpu.numpy()
        labels_x_cpu = np.argmax(labels_x_cpu, axis=1)

        for row_i in range(y_pred.shape[0]):
            if y_pred[row_i] == labels_x_cpu[row_i]:
                correct += 1
            total += 1

        if (i+1) % 100 == 0:
            io.cprint('{} correct from {} \tAccuracy: {:.3f}%'.format(correct, total, 100.0*correct/total))

    io.cprint('{} correct from {} \tAccuracy: {:.3f}%'.format(correct, total, 100.0*correct/total))
    io.cprint('*** TEST FINISHED ***\n')
    enc_nn.train()
    metric_nn.train()

    return 100.0 * correct / total
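
The function above relies on module-level imports (`io_utils`, `generator`, `numpy as np`, `torch.autograd.Variable`) and on an argparse-style `args` namespace exposing `exp_name`, `dataset_root`, `dataset`, `cuda`, `batch_size_test`, `test_N_way`, `test_N_shots` and `unlabeled_extra`. A minimal call sketch, assuming the three modules have already been built and trained as in the training loops below (names and values here are illustrative only):

# Hypothetical usage; enc_nn, metric_nn and softmax_module come from the training code
model = [enc_nn, metric_nn, softmax_module]
val_acc = test_one_shot(args, model, test_samples=1000, partition='val')
print('Validation accuracy: {:.2f}%'.format(val_acc))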
Example 2
def train():
    train_loader = generator.Generator(args.dataset_root,
                                       args,
                                       partition='train',
                                       dataset=args.dataset)
    logger.info('Batch size: ' + str(args.batch_size))

    #Try to load models
    enc_nn = models.load_model('enc_nn', args)
    metric_nn = models.load_model('metric_nn', args)

    if enc_nn is None or metric_nn is None:
        enc_nn, metric_nn = models.create_models(args=args)
    softmax_module = models.SoftmaxModule()

    if args.cuda:
        enc_nn.cuda()
        metric_nn.cuda()

    logger.info(str(enc_nn))
    logger.info(str(metric_nn))

    weight_decay = 0
    if args.dataset == 'mini_imagenet':
        logger.info('Weight decay ' + str(1e-6))
        weight_decay = 1e-6
    opt_enc_nn = optim.Adam(enc_nn.parameters(),
                            lr=args.lr,
                            weight_decay=weight_decay)
    opt_metric_nn = optim.Adam(metric_nn.parameters(),
                               lr=args.lr,
                               weight_decay=weight_decay)

    model_summary([enc_nn, metric_nn])
    optimizer_summary([opt_enc_nn, opt_metric_nn])
    enc_nn.train()
    metric_nn.train()
    counter = 0
    total_loss = 0
    val_acc, val_acc_aux = 0, 0
    test_acc = 0
    for batch_idx in range(args.iterations):

        ####################
        # Train
        ####################
        data = train_loader.get_task_batch(
            batch_size=args.batch_size,
            n_way=args.train_N_way,
            unlabeled_extra=args.unlabeled_extra,
            num_shots=args.train_N_shots,
            cuda=args.cuda,
            variable=True)
        [
            batch_x, label_x, _, _, batches_xi, labels_yi, oracles_yi,
            hidden_labels
        ] = data

        opt_enc_nn.zero_grad()
        opt_metric_nn.zero_grad()

        loss_d_metric = train_batch(model=[enc_nn, metric_nn, softmax_module],
                                    data=[
                                        batch_x, label_x, batches_xi,
                                        labels_yi, oracles_yi, hidden_labels
                                    ])

        opt_enc_nn.step()
        opt_metric_nn.step()

        adjust_learning_rate(optimizers=[opt_enc_nn, opt_metric_nn],
                             lr=args.lr,
                             iter=batch_idx)

        ####################
        # Display
        ####################
        counter += 1
        total_loss += loss_d_metric.item()
        if batch_idx % args.log_interval == 0:
            display_str = 'Train Iter: {}'.format(batch_idx)
            display_str += '\tLoss_d_metric: {:.6f}'.format(total_loss /
                                                            counter)
            logger.info(display_str)
            counter = 0
            total_loss = 0

        ####################
        # Test
        ####################
        if (batch_idx + 1) % args.test_interval == 0 or batch_idx == 20:
            if batch_idx == 20:
                test_samples = 100
            else:
                test_samples = 3000
            if args.dataset == 'mini_imagenet':
                val_acc_aux = test.test_one_shot(
                    args,
                    model=[enc_nn, metric_nn, softmax_module],
                    test_samples=test_samples * 5,
                    partition='val')
            test_acc_aux = test.test_one_shot(
                args,
                model=[enc_nn, metric_nn, softmax_module],
                test_samples=test_samples * 5,
                partition='test')
            test.test_one_shot(args,
                               model=[enc_nn, metric_nn, softmax_module],
                               test_samples=test_samples,
                               partition='train')
            enc_nn.train()
            metric_nn.train()

            if val_acc_aux is not None and val_acc_aux >= val_acc:
                test_acc = test_acc_aux
                val_acc = val_acc_aux

            if args.dataset == 'mini_imagenet':
                logger.info("Best test accuracy {:.4f} \n".format(test_acc))

        ####################
        # Save model
        ####################
        if (batch_idx + 1) % args.save_interval == 0:
            logger.info("saving model...")
            torch.save(enc_nn,
                       os.path.join(logger.get_logger_dir(), 'enc_nn.t7'))
            torch.save(metric_nn,
                       os.path.join(logger.get_logger_dir(), 'metric_nn.t7'))

    # Test after training
    test.test_one_shot(args,
                       model=[enc_nn, metric_nn, softmax_module],
                       test_samples=args.test_samples)
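
Both training loops call an `adjust_learning_rate(optimizers, lr, iter)` helper that is not shown in these examples. Below is a minimal sketch of a step-decay schedule matching that call signature; the decay factor and interval are assumptions for illustration, not taken from the original code:

def adjust_learning_rate(optimizers, lr, iter):
    # Assumed schedule: halve the base learning rate every 15000 iterations
    new_lr = lr * (0.5 ** (iter // 15000))
    for optimizer in optimizers:
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr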
Example 3
def test_one_shot(e_stop,
                  test_cycle,
                  args,
                  model,
                  test_samples=5000,
                  partition='test'):
    """Function used to perform testing (validation and final test)
    
    Parameters:
    e_stop (early_stop): early stop which monitors when we stop training
    test_cycle (int): current iteration cycle - used when creating confusion matrix file names
    args (Namespace): arguments from argparse
    model (list) : contains cnn, gnn and softmax models
    test_samples (int) : number of samples to test
    partition (string) : train, val or test 
    
    Returns:
    e_stop (early_stop) : we return the early_stop object which has been updated based upon the test
    
    """

    io = io_utils.IOStream('checkpoints/' + args.exp_name + '/run.log')

    io.cprint('\n**** TESTING WITH %s ****' % (partition,))

    loader = generator.Generator(args.dataset_root,
                                 args,
                                 partition=partition,
                                 dataset=args.dataset)

    [enc_nn, metric_nn, softmax_module] = model
    enc_nn.eval()
    metric_nn.eval()
    correct = 0
    total = 0
    iterations = int(test_samples / args.batch_size_test)

    true_list = []
    predicted_list = []

    with open(
            os.path.join('datasets', 'compacted_datasets',
                         'sensor_label_decoder.pickle'), 'rb') as handle:
        label_decoder = pickle.load(handle)

    # Decoded labels are backslash-separated paths; keep only the final component
    sep = '\\'
    for temp in range(0, len(label_decoder)):
        label_decoder[temp] = label_decoder[temp].rsplit(sep, 1)[1]

    for i in range(iterations):

        data, labels_dict = loader.get_task_batch(
            batch_size=args.batch_size_test)
        [x, labels_x_cpu, _, x_global, xi_s, labels_yi_cpu] = data

        if args.cuda:
            xi_s = [batch_xi.cuda() for batch_xi in xi_s]
            labels_yi = [label_yi.cuda() for label_yi in labels_yi_cpu]
            x = x.cuda()
        else:
            labels_yi = labels_yi_cpu

        xi_s = [Variable(batch_xi) for batch_xi in xi_s]
        labels_yi = [Variable(label_yi) for label_yi in labels_yi]
        x = Variable(x)

        # Compute embedding from x and xi_s
        z = enc_nn(x)

        zi_s = [enc_nn(batch_xi) for batch_xi in xi_s]

        # Compute metric from embeddings
        output, out_logits = metric_nn(inputs=[z, zi_s, labels_yi])
        output = out_logits

        y_pred = softmax_module.forward(output)

        y_pred = y_pred.data.cpu().numpy()
        y_pred = np.argmax(y_pred, axis=1)
        labels_x_cpu = labels_x_cpu.numpy()
        labels_x_cpu = np.argmax(labels_x_cpu, axis=1)

        # Use a separate index so the outer iteration counter i is not shadowed
        for j in range(len(labels_x_cpu)):
            true_label = labels_dict[j, labels_x_cpu[j]]
            true_list.append(label_decoder[true_label])
            predicted_label = labels_dict[j, y_pred[j]]
            predicted_list.append(label_decoder[predicted_label])

        for row_i in range(y_pred.shape[0]):
            if y_pred[row_i] == labels_x_cpu[row_i]:
                correct += 1
            total += 1

        if (i + 1) % 100 == 0:
            io.cprint('{} correct from {} \tAccuracy: {:.3f}%'.format(
                correct, total, 100.0 * correct / total))
    acc = accuracy_score(true_list, predicted_list)
    # Note: despite its name, 'micro' holds the weighted-average F1 score
    micro = f1_score(true_list, predicted_list, average='weighted')
    macro = f1_score(true_list, predicted_list, average='macro')

    e_stop.update(micro, enc_nn, metric_nn)

    if partition == 'test' or (partition == 'val' and e_stop.improve):
        if partition == 'test':
            test_cycle = 999

        #Print confusion matrix
        conf_mat.conf_mat(true_list, predicted_list, test_cycle, args.train,
                          args.test)

        test_labels = sorted(set(true_list).union(set(predicted_list)))

        print(
            classification_report(true_list,
                                  predicted_list,
                                  target_names=test_labels))

        print("Micro is " + str(micro))

    enc_nn.train()
    metric_nn.train()

    return e_stop
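
Both test functions pass the metric-network logits through `softmax_module.forward(output)` before taking an argmax over classes. Below is a minimal sketch of what `models.SoftmaxModule` could look like under that assumption (this is not the original implementation):

import torch.nn.functional as F

class SoftmaxModule:
    def forward(self, outputs):
        # Turn raw logits into class probabilities along the label dimension
        return F.softmax(outputs, dim=1)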
Example 4
def train():
    """Main function used for training for model. Keeps iterating and updating parameters until early stop condition is reached."""

    #Generator is used to sample batches.
    train_loader = generator.Generator(args.dataset_root,
                                       args,
                                       partition='train',
                                       dataset=args.dataset)

    io.cprint('Batch size: ' + str(args.batch_size))
    print("Learning rate is " + str(args.lr))

    #Try to load models
    enc_nn = models.load_model('enc_nn', args, io)
    metric_nn = models.load_model('metric_nn', args, io)

    #creates models
    if enc_nn is None or metric_nn is None:
        enc_nn, metric_nn = models.create_models(args, train_loader)

    softmax_module = models.SoftmaxModule()
    if args.cuda:
        enc_nn.cuda()
        metric_nn.cuda()

    io.cprint(str(enc_nn))
    io.cprint(str(metric_nn))

    weight_decay = 0
    if args.dataset == 'sensor':
        print('Weight decay ' + str(1e-6))
        weight_decay = 1e-6

    opt_enc_nn = optim.Adam(filter(lambda p: p.requires_grad,
                                   enc_nn.parameters()),
                            lr=args.lr,
                            weight_decay=weight_decay)
    opt_metric_nn = optim.Adam(metric_nn.parameters(),
                               lr=args.lr,
                               weight_decay=weight_decay)

    enc_nn.train()
    metric_nn.train()
    counter = 0
    total_loss = 0
    test_cycle = 0
    batch_idx = 0

    start = time.time()
    print("starting time count")
    e_stop = early_stop.EarlyStopping()

    #Start training loop
    while e_stop.early_stop is False:
        ####################
        # Train
        ####################
        #Load training batch
        data, _ = train_loader.get_task_batch(batch_size=args.batch_size,
                                              cuda=args.cuda,
                                              variable=True)

        [batch_x, label_x, _, _, batches_xi, labels_yi] = data

        opt_enc_nn.zero_grad()
        opt_metric_nn.zero_grad()

        #Calculate loss
        loss_d_metric = train_batch(
            model=[enc_nn, metric_nn, softmax_module],
            data=[batch_x, label_x, batches_xi, labels_yi])
        #Update parameter
        opt_enc_nn.step()
        opt_metric_nn.step()

        #Adjust learning rate
        adjust_learning_rate(optimizers=[opt_enc_nn, opt_metric_nn],
                             lr=args.lr,
                             iter=batch_idx)

        ####################
        # Display
        ####################
        counter += 1
        total_loss += loss_d_metric.item()
        if batch_idx % args.log_interval == 0:
            display_str = 'Train Iter: {}'.format(batch_idx)
            display_str += '\tLoss_d_metric: {:.6f}'.format(total_loss /
                                                            counter)
            io.cprint(display_str)
            counter = 0
            total_loss = 0

        ####################
        # Test
        ####################
        #Testing at specific intervals
        if (batch_idx + 1) % args.test_interval == 0 or batch_idx == 0:
            if batch_idx == 20:
                test_samples = 200
            else:
                test_samples = 300

            e_stop = test.test_one_shot(
                e_stop,
                test_cycle,
                args,
                model=[enc_nn, metric_nn, softmax_module],
                test_samples=test_samples,
                partition='val')

            enc_nn.train()
            metric_nn.train()

            test_cycle = test_cycle + 1

            end = time.time()
            io.cprint("Time elapsed : " + str(end - start))
            print("Time elapsed : " + str(end - start))

        ####################
        # Save model
        ####################
        #Save model at specific interval
        if (batch_idx + 1) % args.save_interval == 0:
            torch.save(enc_nn,
                       'checkpoints/%s/models/enc_nn.t7' % args.exp_name)
            torch.save(metric_nn,
                       'checkpoints/%s/models/metric_nn.t7' % args.exp_name)

        batch_idx = batch_idx + 1

    #Test after training
    #Load best model
    final_enc_nn = models.load_best_model('enc_nn', io)
    final_metric_nn = models.load_best_model('metric_nn', io)

    if args.cuda:
        final_enc_nn.cuda()
        final_metric_nn.cuda()

    test.test_one_shot(e_stop,
                       test_cycle,
                       args,
                       model=[final_enc_nn, final_metric_nn, softmax_module],
                       test_samples=args.test_samples,
                       partition='test')
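
Examples 3 and 4 assume an `early_stop.EarlyStopping` object with an `update(score, enc_nn, metric_nn)` method and `improve` / `early_stop` flags. A minimal sketch of a class with that interface follows; the patience value and checkpoint paths are hypothetical:

import torch

class EarlyStopping:
    def __init__(self, patience=10):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.improve = False
        self.early_stop = False

    def update(self, score, enc_nn, metric_nn):
        # Record whether the monitored metric improved; stop after `patience` stalls
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            self.improve = True
            self.counter = 0
            # Checkpoint the best models (paths are hypothetical)
            torch.save(enc_nn, 'checkpoints/best_enc_nn.t7')
            torch.save(metric_nn, 'checkpoints/best_metric_nn.t7')
        else:
            self.improve = False
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True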