Example #1
def main():
    # set up the experiment directories
    if not args.log_off:
        exp_name = experiment_name_non_mnist()
        exp_dir = os.path.join(args.root_dir, exp_name)

        if not os.path.exists(exp_dir):
            os.makedirs(exp_dir)

        copy_script_to_folder(os.path.abspath(__file__), exp_dir)

        result_png_path = os.path.join(exp_dir, 'results.png')
        log = open(os.path.join(exp_dir, 'log_seed_{}.txt'.format(args.seed)), 'w')
        print_log('save path : {}'.format(exp_dir), log)
    else:
        log = None

    global best_acc

    state = {k: v for k, v in args._get_kwargs()}
    print("")
    print_log(state, log)
    print("")
    print_log("Random Seed: {}".format(args.seed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()),
              log)

    # dataloader
    train_loader, valid_loader, _, test_loader, num_classes = load_data_subset(
        args.batch_size,
        2,  # presumably the number of data-loader workers
        args.dataset,
        args.data_dir,
        labels_per_class=args.labels_per_class,
        valid_labels_per_class=args.valid_labels_per_class,
        mixup_alpha=args.mixup_alpha)

    if args.dataset == 'tiny-imagenet-200':
        stride = 2
        args.mean = torch.tensor([0.5] * 3,
                                 dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.std = torch.tensor([0.5] * 3,
                                dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.labels_per_class = 500
    elif args.dataset == 'cifar10':
        stride = 1
        args.mean = torch.tensor([x / 255 for x in [125.3, 123.0, 113.9]],
                                 dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.std = torch.tensor([x / 255 for x in [63.0, 62.1, 66.7]],
                                dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.labels_per_class = 5000
    elif args.dataset == 'cifar100':
        stride = 1
        args.mean = torch.tensor([x / 255 for x in [129.3, 124.1, 112.4]],
                                 dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.std = torch.tensor([x / 255 for x in [68.2, 65.4, 70.4]],
                                dtype=torch.float32).view(1, 3, 1, 1).cuda()
        args.labels_per_class = 500
    else:
        raise AssertionError(
            'Dataset {} is not supported!'.format(args.dataset))

    # create model
    print_log("=> creating model '{}'".format(args.arch), log)
    net = models.__dict__[args.arch](num_classes, args.dropout, stride).cuda()
    args.num_classes = num_classes

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
    optimizer = torch.optim.SGD(net.parameters(),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    recorder = RecorderMeter(args.epochs)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log(
                "=> loaded checkpoint '{}' accuracy={} (epoch {})".format(
                    args.resume, best_acc, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        validate(test_loader, net, log)
        return

    start_time = time.time()
    epoch_time = AverageMeter()
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch,
                                                     args.gammas,
                                                     args.schedule)
        if epoch == args.schedule[0]:
            args.clean_lam = 0

        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)
        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'
            ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(
                time_string(), epoch, args.epochs, need_time,
                current_learning_rate, recorder.max_accuracy(False),
                100 - recorder.max_accuracy(False)), log)

        # train for one epoch
        tr_acc, tr_acc5, tr_los = train(train_loader, net, optimizer, epoch,
                                        args, log)

        # evaluate on validation set
        val_acc, val_los = validate(test_loader, net, log)
        if (epoch % 50) == 0 and args.adv_p > 0:
            _, _ = validate(test_loader,
                            net,
                            log,
                            fgsm=True,
                            eps=4,
                            mean=args.mean,
                            std=args.std)
            _, _ = validate(test_loader,
                            net,
                            log,
                            fgsm=True,
                            eps=8,
                            mean=args.mean,
                            std=args.std)

        train_loss.append(tr_los)
        train_acc.append(tr_acc)
        test_loss.append(val_los)
        test_acc.append(val_acc)

        is_best = False
        if val_acc > best_acc:
            is_best = True
            best_acc = val_acc

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

        if args.log_off:
            continue

        # save log
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }, is_best, exp_dir, 'checkpoint.pth.tar')

        recorder.update(epoch, tr_los, tr_acc, val_los, val_acc)
        if (epoch + 1) % 100 == 0:
            recorder.plot_curve(result_png_path)

        train_log = OrderedDict()
        train_log['train_loss'] = train_loss
        train_log['train_acc'] = train_acc
        train_log['test_loss'] = test_loss
        train_log['test_acc'] = test_acc

        pickle.dump(train_log, open(os.path.join(exp_dir, 'log.pkl'), 'wb'))
        plotting(exp_dir)

    acc_var = np.maximum(
        np.max(test_acc[-10:]) - np.median(test_acc[-10:]),
        np.median(test_acc[-10:]) - np.min(test_acc[-10:]))
    print_log(
        "\nfinal 10 epoch acc (median) : {:.2f} (+- {:.2f})".format(
            np.median(test_acc[-10:]), acc_var), log)

    if not args.log_off:
        log.close()
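
The example above relies on a few helpers that are not shown in this listing. Below is a minimal sketch of two of them, print_log and adjust_learning_rate, assuming the conventional implementations such training scripts use; the bodies in the source repository may differ.

def print_log(print_string, log):
    # Print to stdout and, when a log file is open, mirror the line to it.
    print('{}'.format(print_string))
    if log is not None:
        log.write('{}\n'.format(print_string))
        log.flush()


def adjust_learning_rate(optimizer, epoch, gammas, schedule):
    # Multiply the base learning rate by every gamma whose scheduled epoch
    # has already passed, then push the result into all parameter groups.
    lr = args.learning_rate
    for (gamma, step) in zip(gammas, schedule):
        if epoch >= step:
            lr = lr * gamma
        else:
            break
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr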
Example #2
def train(return_model=False):
    examples, labels = load_data_subset('Data/train_indices.csv')

    # report the shortest example length in the training set
    print(min(len(ex) for ex in examples))
    if truncate_threshold < 1:
        _, examples, labels = get_examples_below_length_threshold(
            examples, labels, threshold=truncate_threshold)

    print('Training set size: {}'.format(len(examples)))
    print('Num classes: {}'.format(len(set(labels))))

    X, _ = pad_examples(examples)
    # Have to 'flatten' the array to 2D to work with logistic regression
    X = np.reshape(X, (X.shape[0], X.shape[1] * 3))
    train_label_set = sorted(list(set(labels)))
    ls = convert_labels(labels, train_label_set)
    y = np.argmax(ls, axis=1)

    print(X.shape)
    '''Split off the fine-tuning set'''
    sss = StratifiedShuffleSplit(n_splits=1, train_size=0.8)
    for train_index, val_index in sss.split(X=X, y=y):
        X_train = X[train_index]
        y_train = y[train_index]
        X_val = X[val_index]
        y_val = y[val_index]

    crossval_models = []
    crossval_accuracies = []
    crossval_training_accuracies = []
    '''Cross validation'''
    for fold in range(num_folds):
        '''Split off fold'''
        sss = StratifiedShuffleSplit(n_splits=1, train_size=0.8)
        for trainfold_index, valfold_index in sss.split(X=X_train, y=y_train):
            X_trainfold = X_train[trainfold_index]
            y_trainfold = y_train[trainfold_index]
            X_valfold = X_train[valfold_index]
            y_valfold = y_train[valfold_index]

        lr = LogisticRegression(class_weight='balanced',
                                multi_class='multinomial',
                                solver='newton-cg',
                                max_iter=200,
                                tol=1e-4,
                                C=C)
        lr.fit(X_trainfold, y_trainfold)

        predictions = lr.predict(X_valfold)
        accuracy = accuracy_score(y_true=y_valfold, y_pred=predictions)

        crossval_accuracies.append(accuracy)
        crossval_training_accuracies.append(
            accuracy_score(y_true=y_trainfold, y_pred=lr.predict(X_trainfold)))
        crossval_models.append(lr)

        print('Fold {} accuracy: {}'.format(fold, accuracy))

    best_ID = crossval_accuracies.index(max(crossval_accuracies))
    best_model = crossval_models[best_ID]

    predictions = best_model.predict(X_val)
    accuracy = accuracy_score(y_true=y_val, y_pred=predictions)
    print('Best model training_accuracy: {}'.format(
        crossval_training_accuracies[best_ID]))
    print("Mean cross-val accuracy: {}".format(mean(crossval_accuracies)))
    print('Validation accuracy: {}'.format(accuracy))

    if return_model:
        return best_model
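
convert_labels is used above as if it returned one-hot rows: its output goes straight through np.argmax(ls, axis=1). Here is a plausible sketch under that assumption; the helper's real body is not shown in this listing.

import numpy as np


def convert_labels(labels, label_set):
    # One-hot encode each label against the ordered label set.
    index = {label: i for i, label in enumerate(label_set)}
    onehot = np.zeros((len(labels), len(label_set)), dtype=np.float32)
    for row, label in enumerate(labels):
        onehot[row, index[label]] = 1.0
    return onehot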
Example #3
    shuffle_indices = np.random.permutation(size)
    return X[shuffle_indices], y[shuffle_indices]


def get_batch(X, y, batch_num, batch_size):
    start = batch_num * batch_size
    end = (batch_num + 1) * batch_size

    if end >= X.shape[0]:
        return X[start:], y[start:]
    else:
        return X[start:end], y[start:end]

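# Hedged usage sketch (not in the source): how get_batch typically drives one
# training epoch. run_epoch, step_fn, and batch_size are hypothetical names.
def run_epoch(X_epoch, y_epoch, batch_size, step_fn):
    # Visit every example once, in mini-batches of at most batch_size rows.
    num_batches = int(np.ceil(X_epoch.shape[0] / batch_size))
    for batch_num in range(num_batches):
        X_batch, y_batch = get_batch(X_epoch, y_epoch, batch_num, batch_size)
        step_fn(X_batch, y_batch)  # e.g. one optimizer update
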

"""Get Data"""
examples, labels = load_data_subset(indices_path="Data/train_indices.csv")
length_threshold, examples, labels = get_examples_below_length_threshold(
    examples, labels, threshold=length_threshold)

X, X_masks = pad_examples(examples)
y = convert_labels(labels, sorted(list(set(labels))))
"""Split off validation set"""
sss = StratifiedShuffleSplit(n_splits=1, train_size=0.8)
for train_index, val_index in sss.split(X=X, y=y):
    X_train = X[train_index]
    y_train = y[train_index]
    X_val = X[val_index]
    y_val = y[val_index]

for fold in range(num_folds):
    """Split off cross validation fold"""
Example #4
    print('Best model training_accuracy: {}'.format(
        crossval_training_accuracies[best_ID]))
    print("Mean cross-val accuracy: {}".format(mean(crossval_accuracies)))
    print('Validation accuracy: {}'.format(accuracy))

    if return_model:
        return best_model


if not evaluate:
    train()
else:
    model = train(return_model=True)

    # Get max length in training set
    examples, labels = load_data_subset('Data/train_indices.csv')
    if truncate_threshold > 0:
        maxlen, _, _ = get_examples_below_length_threshold(
            examples, labels, threshold=truncate_threshold)
    else:
        # no truncation requested, so keep the longest training example
        maxlen = max(len(ex) for ex in examples)

    # Load test data & truncate
    test_examples, test_labels = load_data_subset('Data/test_indices.csv')
    test_examples, _ = pad_examples(test_examples)
    if test_examples.shape[1] > maxlen:
        test_examples = test_examples[:, :maxlen, :]

    test_examples = np.reshape(
        test_examples, (test_examples.shape[0], test_examples.shape[1] * 3))
    test_label_set = sorted(list(set(test_labels)))
    print(test_label_set)
    test_labels = convert_labels(test_labels, test_label_set)
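
    # Hypothetical continuation (not in the source): score the tuned model on
    # the test set, assuming convert_labels returns one-hot rows as sketched
    # after Example #2.
    y_true = np.argmax(test_labels, axis=1)
    test_predictions = model.predict(test_examples)
    print('Test accuracy: {}'.format(
        accuracy_score(y_true=y_true, y_pred=test_predictions)))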