Example #1
def validate(val_loader, model, criterion, epoch, distill_data, distill_labels,
             num_steps, normalize):
    """Validation phase. This involves training a model from scratch."""
    losses_eval = AverageMeter()
    top1_eval = AverageMeter()
    start = time.time()

    print('Number of steps for retraining during validation: ' +
          str(num_steps))

    # initialize a new model
    if args.target == "cifar10":
        model_eval = M.AlexCifarNetMeta(args).to(device=device)
    else:
        model_eval = M.LeNetMeta(args).to(device=device)
    optimizer = torch.optim.Adam(model_eval.parameters())
    model_eval.train()

    # train a model from scratch for the given number of steps
    # using only the synthetic images and their labels
    for ITER in range(num_steps):
        # sample a minibatch of n_i examples from x~, y~: x~', y~'
        perm = torch.randperm(distill_data.size(0))
        idx = perm[:args.inner_batch_size]
        fi_i = torch.stack([
            normalize(image) for image in distill_data[idx].detach().cpu()
        ]).to(device=device)
        lb_i = distill_labels[idx]

        fi_o = model_eval(fi_i)
        loss = soft_cross_entropy(fi_o, lb_i)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # evaluate the validation error
    # switch to evaluate mode
    model_eval.eval()

    for i, (input_, target) in enumerate(val_loader):
        input_ = input_.to(device=device)
        target = target.to(device=device)
        output = model_eval(input_)
        loss = criterion(output, target)

        # measure error rate and record loss
        err1 = compute_error_rate(output.data, target, topk=(1, ))[0]
        top1_eval.update(err1.item(), input_.size(0))
        losses_eval.update(loss.item(), input_.size(0))

    val_time = time.time() - start
    print(
        '* Epoch: [{0}/{1}]\t Top 1-err {top1.avg:.3f} Val Loss {loss.avg:.3f} Time {val_time:.3f}'
        .format(epoch,
                args.epochs,
                top1=top1_eval,
                loss=losses_eval,
                val_time=val_time))

    return top1_eval.avg, losses_eval.avg
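
Example #1 assumes several helpers that are not shown in this excerpt: AverageMeter, soft_cross_entropy and compute_error_rate. A minimal sketch of plausible definitions follows; the actual implementations in the repository may differ.

import torch
import torch.nn.functional as F


class AverageMeter:
    """Keep a running average of a scalar metric."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def soft_cross_entropy(logits, soft_targets):
    """Cross-entropy against soft (probability-vector) targets."""
    return torch.mean(
        torch.sum(-soft_targets * F.log_softmax(logits, dim=1), dim=1))


def compute_error_rate(output, target, topk=(1, )):
    """Top-k error rates in percent; returns one value per k, as a list."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(100.0 - correct_k * (100.0 / batch_size))
    return res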
Example #2
def train(train_loader, model, distill_data, distill_labels, criterion,
          data_opt, epoch, optimizer, ma_list, ma_sum, lowest_ma_sum,
          current_num_steps, num_steps_list, num_steps_from_min, normalize):
    """
    Do one epoch of training the synthetic images.

    The parameters ma_list, ma_sum, lowest_ma_sum, current_num_steps, num_steps_list and
    num_steps_from_min track the moving-average statistics used to decide when to reset the model across epochs.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    model_losses = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    for i, (input_, target) in enumerate(train_loader):
        # sample a minibatch of n_o target dataset examples x_t' with labels y_t'

        # measure data loading time
        data_time.update(time.time() - end)

        input_ = input_.to(device=device)
        target = target.to(device=device)

        # sample a minibatch of n_i examples from x~, y~: x~', y~'
        perm = torch.randperm(distill_data.size(0))
        idx = perm[:args.inner_batch_size]
        # normalize the synthetic images using the standard normalization
        fi_i = torch.stack([
            normalize(image) for image in distill_data[idx].cpu()
        ]).to(device=device)
        lb_i = distill_labels[idx]

        # inner loop
        for weight in model.parameters():
            weight.fast = None
        fi_o = model(fi_i)
        loss = soft_cross_entropy(fi_o, lb_i)
        optimizer.zero_grad()
        grad = torch.autograd.grad(loss, model.parameters(), create_graph=True)
        # create fast weights so that we can use second-order gradient
        for k, weight in enumerate(model.parameters()):
            weight.fast = weight - args.inner_lr * grad[k]

        # outer loop
        # update x~ <-- x~ - beta nabla_x~ L(f_w(f_theta(x_t)), y_t)
        # fast weights will be used
        logit = model(input_)
        data_loss = criterion(logit, target)
        data_opt.zero_grad()
        data_loss.backward(retain_graph=False)
        data_opt.step()

        # make sure the synthetic examples have valid values after the update
        distill_data.data = torch.clamp(distill_data.data, 0, 1)

        # calculate the loss again for updating the features
        fi_i = torch.stack([
            normalize(image) for image in distill_data[idx].cpu()
        ]).to(device=device)
        lb_i = distill_labels[idx]
        for weight in model.parameters():
            weight.fast = None
        fi_o = model(fi_i)
        loss = soft_cross_entropy(fi_o, lb_i)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        model_losses.update(loss.item(), input_.size(0))

        # measure error rate and record loss
        err1 = compute_error_rate(logit.data, target,
                                  topk=(1, ))[0]  # it returns a list

        losses.update(data_loss.item(), input_.size(0))
        top1.update(err1.item(), input_.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # update the moving average statistics
        # for image distillation, both the moving-average window and the patience are fixed at 50 steps
        if len(ma_list) < 50:
            ma_list.append(err1.item())
            ma_sum += err1.item()
            current_num_steps += 1

            current_ma = ma_sum / len(ma_list)

            if current_num_steps == 50:
                lowest_ma_sum = ma_sum
                num_steps_from_min = 0
        else:
            ma_sum = ma_sum - ma_list[0] + err1.item()
            ma_list = ma_list[1:] + [err1.item()]
            current_num_steps += 1

            current_ma = ma_sum / len(ma_list)

            if ma_sum < lowest_ma_sum:
                lowest_ma_sum = ma_sum
                num_steps_from_min = 0
            elif num_steps_from_min < 50:
                num_steps_from_min += 1
            else:
                # do early stopping
                num_steps_list.append(current_num_steps - num_steps_from_min -
                                      1)
                # restart all metrics
                ma_list = []
                ma_sum = 0
                lowest_ma_sum = 999999999
                current_num_steps = 0
                num_steps_from_min = 0

                # restart the model and the optimizer
                if args.target == "cifar10":
                    model = M.AlexCifarNetMeta(args).to(device=device)
                else:
                    model = M.LeNetMeta(args).to(device=device)
                optimizer = torch.optim.Adam(model.parameters())

                print('Model restarted after ' + str(num_steps_list[-1]) +
                      ' steps')

        if i % args.print_freq == 0 and args.verbose is True:
            print('Epoch: [{0}/{1}][{2}/{3}]\t'
                  'LR: {LR:.6f}\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\t'
                  'MAvg 1-err {current_ma:.4f}'.format(epoch,
                                                       args.epochs,
                                                       i,
                                                       len(train_loader),
                                                       LR=1,  # placeholder: Adam manages its own learning rate
                                                       batch_time=batch_time,
                                                       data_time=data_time,
                                                       loss=losses,
                                                       top1=top1,
                                                       current_ma=current_ma))

    print(
        '* Epoch: [{0}/{1}]\t Top 1-err {top1.avg:.3f}  Train Loss {loss.avg:.3f}'
        .format(epoch, args.epochs, top1=top1, loss=losses))

    return top1.avg, losses.avg, distill_data, model_losses.avg, \
        ma_list, ma_sum, lowest_ma_sum, current_num_steps, num_steps_list, num_steps_from_min, \
        model, optimizer
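
The inner loop above depends on the meta-architectures consulting an optional weight.fast attribute during the forward pass: the gradient is taken with create_graph=True, the fast weights keep that graph alive, and the outer data_loss.backward() can therefore propagate through the one-step update into distill_data. A minimal sketch of a layer following this convention (LinearMeta is a hypothetical name; the repository's LeNetMeta and AlexCifarNetMeta presumably do the equivalent for their own layers):

import torch.nn as nn
import torch.nn.functional as F


class LinearMeta(nn.Linear):
    """Linear layer that uses weight.fast / bias.fast when they are set."""

    def forward(self, x):
        # use the fast weights from the inner-loop step if present,
        # otherwise fall back to the regular (slow) parameters
        weight = (self.weight.fast
                  if getattr(self.weight, 'fast', None) is not None
                  else self.weight)
        bias = (self.bias.fast
                if getattr(self.bias, 'fast', None) is not None
                else self.bias)
        return F.linear(x, weight, bias)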
Example #3
def main():
    global args, best_err1, device, num_classes
    args = get_args()
    torch.manual_seed(args.random_seed)
    best_err1 = 100

    # define datasets
    if args.target == "mnist":
        normalize = transforms.Normalize((0.1307, ), (0.3081, ))
        transform_train = transforms.Compose(
            [transforms.ToTensor(), normalize])

        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        train_set_all = datasets.MNIST('data',
                                       train=True,
                                       transform=transform_train,
                                       target_transform=None,
                                       download=True)
        # set aside 10000 examples from the training set for validation
        train_set, val_set = torch.utils.data.random_split(
            train_set_all, [50000, 10000])
        # if we do experiments with variable target set size, this will take care of it
        target_set_size = min(50000, args.target_set_size)
        train_set, _ = torch.utils.data.random_split(
            train_set, [target_set_size, 50000 - target_set_size])
        test_set = datasets.MNIST('data',
                                  train=False,
                                  transform=transform_test,
                                  target_transform=None,
                                  download=True)
        num_classes = 10
        num_channels = 1
        input_size = 28
    elif args.target == "cifar10":
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        transform_train = transforms.Compose(
            [transforms.Resize((32, 32)),
             transforms.ToTensor(), normalize])

        transform_test = transforms.Compose(
            [transforms.Resize((32, 32)),
             transforms.ToTensor(), normalize])
        train_set_all = datasets.CIFAR10('data',
                                         train=True,
                                         transform=transform_train,
                                         target_transform=None,
                                         download=True)
        # set aside 5000 examples from the training set for validation
        train_set, val_set = torch.utils.data.random_split(
            train_set_all, [45000, 5000])
        # if we do experiments with variable target set size, this will take care of it
        target_set_size = min(45000, args.target_set_size)
        train_set, _ = torch.utils.data.random_split(
            train_set, [target_set_size, 45000 - target_set_size])
        test_set = datasets.CIFAR10('data',
                                    train=False,
                                    transform=transform_test,
                                    target_transform=None,
                                    download=True)
        num_classes = 10
        num_channels = 3
        input_size = 32
    else:
        raise ValueError("The dataset is not currently supported")

    # create data loaders
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              shuffle=False)

    if torch.cuda.is_available():  # checks whether a cuda gpu is available
        device = torch.cuda.current_device()
        print("use GPU", device)
        print("GPU ID {}".format(torch.cuda.current_device()))
    else:
        print("use CPU")
        device = torch.device('cpu')  # sets the device to be CPU

    # randomly initialize the synthetic images and create the associated
    # labels [0, 1, 2, ..., 0, 1, 2, ...]
    distill_labels = torch.arange(num_classes, dtype=torch.long, device=device) \
        .repeat(args.num_base_examples // num_classes, 1).reshape(-1)
    distill_labels = one_hot(distill_labels, num_classes)
    distill_data = torch.rand(args.num_base_examples,
                              num_channels,
                              input_size,
                              input_size,
                              device=device,
                              requires_grad=True)
    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().to(device=device)
    data_opt = torch.optim.Adam([distill_data])
    cudnn.benchmark = True
    M.LeNetMeta.meta = True
    M.AlexCifarNetMeta.meta = True

    # define the models to use
    if args.target == "cifar10":
        model = M.AlexCifarNetMeta(args).to(device=device)
    else:
        model = M.LeNetMeta(args).to(device=device)
    optimizer = torch.optim.Adam(model.parameters())

    create_json_experiment_log()

    # start measuring time
    start_time = time.time()

    # initialize early stopping variables
    ma_list = []
    ma_sum = 0
    lowest_ma_sum = 999999999
    current_num_steps = 0
    num_steps_list = []
    num_steps_from_min = 0

    val_err1 = 100.0
    val_loss = 5.0
    num_steps_val = 0

    with tqdm.tqdm(total=args.epochs) as pbar_epochs:
        for epoch in range(0, args.epochs):
            train_err1, train_loss, distill_data, model_loss, ma_list, ma_sum, lowest_ma_sum, current_num_steps, num_steps_list, num_steps_from_min, model, optimizer = \
                train(train_loader, model, distill_data, distill_labels, criterion, data_opt, epoch, optimizer,
                      ma_list, ma_sum, lowest_ma_sum, current_num_steps, num_steps_list, num_steps_from_min, normalize)
            # evaluate on the validation set only every 5 epochs as it can be quite expensive to train a new model from scratch
            if epoch % 5 == 4:
                # calculate the number of steps to use
                if len(num_steps_list) == 0:
                    num_steps_val = current_num_steps
                else:
                    num_steps_val = int(np.mean(num_steps_list[-3:]))

                val_err1, val_loss = validate(val_loader, model, criterion,
                                              epoch, distill_data,
                                              distill_labels, num_steps_val,
                                              normalize)
                # otherwise the stats keep the previous value

                if val_err1 <= best_err1:
                    best_distill_data = distill_data.detach().clone()
                    best_num_steps = num_steps_val
                    best_err1 = min(val_err1, best_err1)

                print('Current best val error (top-1 error):', best_err1)

            pbar_epochs.update(1)

            experiment_update_dict = {
                'train_top_1_error': train_err1,
                'train_loss': train_loss,
                'val_top_1_error': val_err1,
                'val_loss': val_loss,
                'model_loss': model_loss,
                'epoch': epoch,
                'num_val_steps': num_steps_val
            }
            # save the best images so that we can analyse them
            if epoch == args.epochs - 1:
                experiment_update_dict['data'] = best_distill_data.tolist()

            update_json_experiment_log_dict(experiment_update_dict)

    print('Best val error (top-1 error):', best_err1)

    # stop measuring time
    experiment_update_dict = {'total_train_time': time.time() - start_time}
    update_json_experiment_log_dict(experiment_update_dict)

    # this does number of steps analysis - what happens if we do more or fewer steps for training
    if args.num_steps_analysis:
        num_steps_add = [-50, -20, -10, 0, 10, 20, 50, 100]

        for num_steps_add_item in num_steps_add:
            # start measuring time for testing
            start_time = time.time()
            local_errs = []
            local_losses = []
            local_num_steps = best_num_steps + num_steps_add_item
            print('Number of steps for training: ' + str(local_num_steps))
            # each number of steps will have a robust estimate by using 20 repetitions
            for test_i in range(20):
                print('Test repetition ' + str(test_i))
                test_err1, test_loss = test(test_loader, model, criterion,
                                            best_distill_data, distill_labels,
                                            local_num_steps, normalize)
                local_errs.append(test_err1)
                local_losses.append(test_loss)
                print('Test error (top-1 error):', test_err1)
            experiment_update_dict = {
                'test_top_1_error': local_errs,
                'test_loss': local_losses,
                'total_test_time': time.time() - start_time,
                'num_test_steps': local_num_steps
            }
            update_json_experiment_log_dict(experiment_update_dict)
    else:
        # evaluate on the test set repeatedly for a robust estimate
        # restart the timer so that the logged test time excludes training
        start_time = time.time()
        for test_i in range(20):
            print('Test repetition ' + str(test_i))
            test_err1, test_loss = test(test_loader, model, criterion,
                                        best_distill_data, distill_labels,
                                        best_num_steps, normalize)

            print('Test error (top-1 error):', test_err1)
            experiment_update_dict = {
                'test_top_1_error': test_err1,
                'test_loss': test_loss,
                'total_test_time': time.time() - start_time,
                'num_test_steps': best_num_steps
            }
            update_json_experiment_log_dict(experiment_update_dict)
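
one_hot, create_json_experiment_log and update_json_experiment_log_dict are defined elsewhere in the repository. A plausible minimal sketch is below; the file path and log structure are assumptions, and the label-distillation script (Example #7) passes fixed_target to create_json_experiment_log so it can also record the base-example labels, a detail this sketch omits.

import json
import os

import torch


def one_hot(targets, num_classes):
    """Convert integer class labels into one-hot float vectors."""
    return torch.eye(num_classes, device=targets.device)[targets]


def create_json_experiment_log(log_path='experiment_log.json'):
    """Create an empty JSON experiment log if one does not exist yet."""
    if not os.path.exists(log_path):
        with open(log_path, 'w') as f:
            json.dump({}, f)


def update_json_experiment_log_dict(update_dict, log_path='experiment_log.json'):
    """Append each value in update_dict to the list stored under its key."""
    with open(log_path) as f:
        log = json.load(f)
    for key, value in update_dict.items():
        log.setdefault(key, []).append(value)
    with open(log_path, 'w') as f:
        json.dump(log, f)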
Example #4
def find_best_num_steps(val_loader, criterion, fixed_input, labels):
    """Calculate the best number of steps to use based on the validation set."""
    best_num_steps = 0
    lowest_err = 100
    errors_list = []
    num_steps_used = []

    # use a larger number of max steps when using more base examples
    if args.num_base_examples > 100:
        max_num_steps = 1701
    else:
        max_num_steps = 1000

    # initialize a new model
    if args.target == "cifar10" or args.target == "cifar100":
        if args.resnet:
            model_eval = M.ResNetMeta(dataset=args.target,
                                      depth=18,
                                      num_classes=num_classes,
                                      bottleneck=False,
                                      device=device).to(device=device)
        else:
            model_eval = M.AlexCifarNetMeta(args).to(device=device)
    else:
        model_eval = M.LeNetMeta(args).to(device=device)
    optimizer = torch.optim.Adam(model_eval.parameters())
    model_eval.train()

    for ITER in range(max_num_steps):
        model_eval.train()
        # sample a minibatch of n_i examples from x~, y~: x~', y~'
        perm = torch.randperm(fixed_input.size(0))
        idx = perm[:args.inner_batch_size]
        fi_i = fixed_input[idx]
        lb_i = labels[idx].detach()

        fi_o = model_eval(fi_i)
        loss = soft_cross_entropy(fi_o, lb_i)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # for larger numbers of base examples we decrease the frequency of evaluation
        if args.num_base_examples > 100:
            validate_now = ITER in {9, 24, 74, 100, 200, 300, 500, 700,
                                    1000, 1200, 1500, 1700}
        else:
            validate_now = ITER % 50 == 49 or ITER in {9, 24, 74}

        if validate_now:
            # switch to evaluate mode
            model_eval.eval()
            top1_eval = AverageMeter()
            losses_eval = AverageMeter()

            for i, (input_, target) in enumerate(val_loader):
                input_ = input_.to(device=device)
                target = target.to(device=device)
                output = model_eval(input_)
                loss = criterion(output, target)

                # measure error and record loss
                err1 = compute_error_rate(output.data, target, topk=(1, ))[0]
                top1_eval.update(err1.item(), input_.size(0))
                losses_eval.update(loss.item(), input_.size(0))

            errors_list.append(top1_eval.avg)
            num_steps_used.append(ITER + 1)

            if top1_eval.avg < lowest_err:
                lowest_err = top1_eval.avg
                best_num_steps = ITER + 1

    print(num_steps_used)
    print(errors_list)
    return best_num_steps, errors_list, num_steps_used
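
Because the function keeps the first minimum it encounters, the returned step count can be re-derived offline from the logged lists. A hypothetical usage sketch, assuming val_loader, criterion, fixed_input and labels are already in scope:

import numpy as np

best_num_steps, errors_list, num_steps_used = find_best_num_steps(
    val_loader, criterion, fixed_input, labels)
# np.argmin also returns the first minimum, so the two selections agree
assert best_num_steps == num_steps_used[int(np.argmin(errors_list))]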
Example #5
def test(test_loader, model_name, criterion, fixed_input, labels, num_steps):
    """Test phase. This involves training a model from scratch."""
    losses_eval = AverageMeter()
    top1_eval = AverageMeter()
    start = time.time()

    print('Number of steps for retraining during test: ' + str(num_steps))

    # initialize a new model
    if args.target == "cifar10" or args.target == "cifar100":
        if args.test_various_models:
            if model_name == 'resnet':
                model_eval = M.ResNetMeta(dataset=args.target,
                                          depth=18,
                                          num_classes=num_classes,
                                          bottleneck=False,
                                          device=device).to(device=device)
            elif model_name == 'LeNet':
                model_eval = M.LeNetMeta(args).to(device=device)
            else:
                model_eval = M.AlexCifarNetMeta(args).to(device=device)
        else:
            if args.resnet:
                model_eval = M.ResNetMeta(dataset=args.target,
                                          depth=18,
                                          num_classes=num_classes,
                                          bottleneck=False,
                                          device=device).to(device=device)
            else:
                model_eval = M.AlexCifarNetMeta(args).to(device=device)
    else:
        model_eval = M.LeNetMeta(args).to(device=device)
    optimizer = torch.optim.Adam(model_eval.parameters())
    model_eval.train()

    # train a model from scratch for the given number of steps
    # using only the base examples and their synthetic labels
    for ITER in range(num_steps):
        # sample a minibatch of n_i examples from x~, y~: x~', y~'
        perm = torch.randperm(fixed_input.size(0))
        idx = perm[:args.inner_batch_size]
        fi_i = fixed_input[idx]
        lb_i = labels[idx].detach()

        fi_o = model_eval(fi_i)
        loss = soft_cross_entropy(fi_o, lb_i)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # evaluate the test error
    # switch to evaluate mode
    model_eval.eval()

    for i, (input_, target) in enumerate(test_loader):
        input_ = input_.to(device=device)
        target = target.to(device=device)
        output = model_eval(input_)
        loss = criterion(output, target)

        # measure error and record loss
        err1 = compute_error_rate(output.data, target, topk=(1, ))[0]
        top1_eval.update(err1.item(), input_.size(0))
        losses_eval.update(loss.item(), input_.size(0))

    test_time = time.time() - start
    print('Testing with a model trained from scratch')
    print('Test time: ' + str(test_time))
    print('Test error (top-1 error): {top1_eval.avg:.4f}'.format(
        top1_eval=top1_eval))

    return top1_eval.avg, losses_eval.avg
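
Since test() retrains a model from scratch, its result varies between runs, which is why the calling code repeats it 20 times. A small sketch of aggregating those repetitions into a mean and standard deviation (all names are assumed to be in scope):

import numpy as np

test_errs = []
for test_i in range(20):
    test_err1, _ = test(test_loader, model_name, criterion, fixed_input,
                        labels, num_steps)
    test_errs.append(test_err1)
print('Top-1 error over 20 runs: {:.2f} +/- {:.2f}'.format(
    np.mean(test_errs), np.std(test_errs)))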
Example #6
def train(train_loader, model, fixed_input, labels, criterion, labels_opt,
          epoch, optimizer, ma_list, ma_sum, lowest_ma_sum, current_num_steps,
          num_steps_list, num_steps_from_min):
    """
    Do one epoch of training the synthetic labels.

    The parameters ma_list, ma_sum, lowest_ma_sum, current_num_steps, num_steps_list and
    num_steps_from_min track the moving-average statistics used to decide when to reset the model across epochs.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    model_losses = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    # define over how many steps to calculate the moving average
    # for resetting the model and also how many steps to wait
    # we use the same value for both
    if args.target == "cifar100" or args.target == "k49":
        stats_gap = 100
    elif args.num_base_examples > 100:
        stats_gap = 200
    else:
        stats_gap = 50

    for i, (input_, target) in enumerate(train_loader):
        # sample a minibatch of n_o target dataset examples x_t' with labels y_t'

        # measure data loading time
        data_time.update(time.time() - end)

        input_ = input_.to(device=device)
        target = target.to(device=device)

        # sample a minibatch of n_i base examples from x~, y~: x~', y~'
        perm = torch.randperm(fixed_input.size(0))
        idx = perm[:args.inner_batch_size]
        fi_i = fixed_input[idx]
        lb_i = labels[idx]

        # inner loop
        for weight in model.parameters():
            weight.fast = None
        fi_o = model(fi_i)
        loss = soft_cross_entropy(fi_o, lb_i)
        optimizer.zero_grad()
        grad = torch.autograd.grad(loss, model.parameters(), create_graph=True)
        # create fast weights so that we can use second-order gradient
        for k, weight in enumerate(model.parameters()):
            weight.fast = weight - args.inner_lr * grad[k]

        # outer loop
        # update y~ <-- y~ - beta nabla_y~ L(f_theta(x_t'), y_t')
        # fast weights will be used
        logit = model(input_)
        label_loss = criterion(logit, target)
        labels_opt.zero_grad()
        # retain_graph=True allows us to update the feature extractor
        # without calculating the loss again
        label_loss.backward(retain_graph=True)
        labels_opt.step()

        # normalize the labels to form a valid probability distribution
        labels.data = torch.clamp(labels.data, 0, 1)
        labels.data = labels.data / labels.data.sum(dim=1).unsqueeze(1)

        # now update the model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        model_losses.update(loss.item(), input_.size(0))

        # measure error and record loss
        err1 = compute_error_rate(logit.data, target,
                                  topk=(1, ))[0]  # it returns a list

        losses.update(label_loss.item(), input_.size(0))
        top1.update(err1.item(), input_.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # update the moving average statistics
        if len(ma_list) < stats_gap:
            ma_list.append(err1.item())
            ma_sum += err1.item()
            current_num_steps += 1

            current_ma = ma_sum / len(ma_list)

            if current_num_steps == stats_gap:
                lowest_ma_sum = ma_sum
                num_steps_from_min = 0
        else:
            ma_sum = ma_sum - ma_list[0] + err1.item()
            ma_list = ma_list[1:] + [err1.item()]
            current_num_steps += 1

            current_ma = ma_sum / len(ma_list)

            if ma_sum < lowest_ma_sum:
                lowest_ma_sum = ma_sum
                num_steps_from_min = 0
            elif num_steps_from_min < stats_gap:
                num_steps_from_min += 1
            else:
                # do early stopping
                num_steps_list.append(current_num_steps - num_steps_from_min -
                                      1)
                # restart all metrics
                ma_list = []
                ma_sum = 0
                lowest_ma_sum = 999999999
                current_num_steps = 0
                num_steps_from_min = 0

                # restart the model and the optimizer
                if args.target == "cifar10" or args.target == "cifar100":
                    if args.resnet:
                        model = M.ResNetMeta(dataset=args.target,
                                             depth=18,
                                             num_classes=num_classes,
                                             bottleneck=False,
                                             device=device).to(device=device)
                    else:
                        model = M.AlexCifarNetMeta(args).to(device=device)
                else:
                    model = M.LeNetMeta(args).to(device=device)
                optimizer = torch.optim.Adam(model.parameters())

                print('Model restarted after ' + str(num_steps_list[-1]) +
                      ' steps')

        if i % args.print_freq == 0 and args.verbose is True:
            print('Epoch: [{0}/{1}][{2}/{3}]\t'
                  'LR: {LR:.6f}\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\t'
                  'MAvg 1-err {current_ma:.4f}'.format(epoch,
                                                       args.epochs,
                                                       i,
                                                       len(train_loader),
                                                       LR=1,  # placeholder: Adam manages its own learning rate
                                                       batch_time=batch_time,
                                                       data_time=data_time,
                                                       loss=losses,
                                                       top1=top1,
                                                       current_ma=current_ma))

    print(
        '* Epoch: [{0}/{1}]\t Top 1-err {top1.avg:.3f}  Train Loss {loss.avg:.3f}'
        .format(epoch, args.epochs, top1=top1, loss=losses))

    return top1.avg, losses.avg, labels, model_losses.avg, \
        ma_list, ma_sum, lowest_ma_sum, current_num_steps, num_steps_list, num_steps_from_min, \
        model, optimizer
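
The clamp-and-renormalize step in the outer loop keeps the learned labels non-negative and summing to one, but it is not the exact Euclidean projection onto the probability simplex. For reference, a minimal sketch of the exact sorting-based projection (Duchi et al., 2008) that could be swapped in; the cheaper clamp-and-renormalize used above may well be adequate in practice:

import torch


def project_to_simplex(v):
    """Exact Euclidean projection of each row of v onto the probability simplex."""
    n = v.size(1)
    u, _ = torch.sort(v, dim=1, descending=True)
    cssv = u.cumsum(dim=1) - 1.0
    ind = torch.arange(1, n + 1, device=v.device, dtype=v.dtype)
    rho = (u * ind > cssv).sum(dim=1)  # support size for each row
    theta = cssv.gather(1, (rho - 1).unsqueeze(1)) / rho.to(v.dtype).unsqueeze(1)
    return torch.clamp(v - theta, min=0.0)

# hypothetical drop-in replacement for the two labels.data updates above:
# labels.data = project_to_simplex(labels.data)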
Example #7
def main():
    global args, best_err1, device, num_classes
    args = get_args()
    torch.manual_seed(args.random_seed)

    # most cases have 10 classes
    # if there are more, then it will be reassigned
    num_classes = 10
    best_err1 = 100

    # define datasets
    if args.target == "mnist":
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        train_set_all = datasets.MNIST('data',
                                       train=True,
                                       transform=transform_train,
                                       target_transform=None,
                                       download=True)
        # set aside 10000 examples from the training set for validation
        train_set, val_set = torch.utils.data.random_split(
            train_set_all, [50000, 10000])
        # if we do experiments with variable target set size, this will take care of it
        # by default the target set size is 50000
        target_set_size = min(50000, args.target_set_size)
        train_set, _ = torch.utils.data.random_split(
            train_set, [target_set_size, 50000 - target_set_size])
        test_set = datasets.MNIST('data',
                                  train=False,
                                  transform=transform_test,
                                  target_transform=None,
                                  download=True)
    elif args.target == "kmnist":
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        train_set_all = datasets.KMNIST('data',
                                        train=True,
                                        transform=transform_train,
                                        target_transform=None,
                                        download=True)
        # set aside 10000 examples from the training set for validation
        train_set, val_set = torch.utils.data.random_split(
            train_set_all, [50000, 10000])
        target_set_size = min(50000, args.target_set_size)
        # if we do experiments with variable target set size, this will take care of it
        train_set, _ = torch.utils.data.random_split(
            train_set, [target_set_size, 50000 - target_set_size])
        test_set = datasets.KMNIST('data',
                                   train=False,
                                   transform=transform_test,
                                   target_transform=None,
                                   download=True)
    elif args.target == "k49":
        num_classes = 49
        train_images = np.load('./data/k49-train-imgs.npz')['arr_0']
        test_images = np.load('./data/k49-test-imgs.npz')['arr_0']
        train_labels = np.load('./data/k49-train-labels.npz')['arr_0']
        test_labels = np.load('./data/k49-test-labels.npz')['arr_0']

        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        # set aside about 10% of training data for validation
        train_set_all = K49Dataset(train_images,
                                   train_labels,
                                   transform=transform_train)
        train_set, val_set = torch.utils.data.random_split(
            train_set_all, [209128, 23237])

        # currently we do not support variable target set size for k49
        # enable this to use it
        # target_set_size = min(209128, args.target_set_size)
        # train_set, _ = torch.utils.data.random_split(
        #     train_set, [target_set_size, 209128 - target_set_size])
        test_set = K49Dataset(test_images,
                              test_labels,
                              transform=transform_test)
    elif args.target == "cifar10":
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        transform_train = transforms.Compose(
            [transforms.Resize((32, 32)),
             transforms.ToTensor(), normalize])

        transform_test = transforms.Compose(
            [transforms.Resize((32, 32)),
             transforms.ToTensor(), normalize])

        train_set_all = datasets.CIFAR10('data',
                                         train=True,
                                         transform=transform_train,
                                         target_transform=None,
                                         download=True)
        # set aside 5000 examples from the training set for validation
        train_set, val_set = torch.utils.data.random_split(
            train_set_all, [45000, 5000])
        # if we do experiments with variable target set size, this will take care of it
        target_set_size = min(45000, args.target_set_size)
        train_set, _ = torch.utils.data.random_split(
            train_set, [target_set_size, 45000 - target_set_size])
        test_set = datasets.CIFAR10('data',
                                    train=False,
                                    transform=transform_test,
                                    target_transform=None,
                                    download=True)
    elif args.target == "cifar100":
        num_classes = 100
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        transform_train = transforms.Compose(
            [transforms.Resize((32, 32)),
             transforms.ToTensor(), normalize])

        transform_test = transforms.Compose(
            [transforms.Resize((32, 32)),
             transforms.ToTensor(), normalize])

        train_set_all = datasets.CIFAR100('data',
                                          train=True,
                                          transform=transform_train,
                                          target_transform=None,
                                          download=True)
        # set aside 5000 examples from the training set for validation
        train_set, val_set = torch.utils.data.random_split(
            train_set_all, [45000, 5000])
        # if we do experiments with variable target set size, this will take care of it
        target_set_size = min(45000, args.target_set_size)
        train_set, _ = torch.utils.data.random_split(
            train_set, [target_set_size, 45000 - target_set_size])
        test_set = datasets.CIFAR100('data',
                                     train=False,
                                     transform=transform_test,
                                     target_transform=None,
                                     download=True)
    else:
        raise ValueError("The dataset is not currently supported")

    # create data loaders
    if args.baseline:
        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=args.num_base_examples, shuffle=True)
    else:
        train_loader = torch.utils.data.DataLoader(train_set,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              shuffle=False)

    # create data loaders to get base examples
    if args.source == "emnist":
        train_set_source = datasets.EMNIST('data',
                                           'letters',
                                           train=True,
                                           download=True,
                                           transform=transform_train,
                                           target_transform=None)
        train_loader_source = torch.utils.data.DataLoader(
            train_set_source, batch_size=args.num_base_examples, shuffle=True)
    elif args.source == "mnist":
        train_set_source = datasets.MNIST('data',
                                          train=True,
                                          download=True,
                                          transform=transform_train,
                                          target_transform=None)
        train_loader_source = torch.utils.data.DataLoader(
            train_set_source, batch_size=args.num_base_examples, shuffle=True)
    elif args.source == "kmnist":
        train_set_source = datasets.KMNIST('data',
                                           train=True,
                                           download=True,
                                           transform=transform_train,
                                           target_transform=None)
        train_loader_source = torch.utils.data.DataLoader(
            train_set_source, batch_size=args.num_base_examples, shuffle=True)
    elif args.source == "cifar10":
        train_set_source = datasets.CIFAR10('data',
                                            train=True,
                                            download=True,
                                            transform=transform_train,
                                            target_transform=None)
        train_loader_source = torch.utils.data.DataLoader(
            train_set_source, batch_size=args.num_base_examples, shuffle=True)
    elif args.source == "cifar100":
        train_set_source = datasets.CIFAR100('data',
                                             train=True,
                                             download=True,
                                             transform=transform_train,
                                             target_transform=None)
        train_loader_source = torch.utils.data.DataLoader(
            train_set_source, batch_size=args.num_base_examples, shuffle=True)
    elif args.source == "svhn":
        train_set_source = datasets.SVHN('data',
                                         split='train',
                                         download=True,
                                         transform=transform_train,
                                         target_transform=None)
        train_loader_source = torch.utils.data.DataLoader(
            train_set_source, batch_size=args.num_base_examples, shuffle=True)
    elif args.source == "cub":
        # modify the root depending on where you place the images
        cub_data_root = './data/CUB_200_2011/images'
        train_set_source = datasets.ImageFolder(cub_data_root,
                                                transform=transform_train,
                                                target_transform=None)
        train_loader_source = torch.utils.data.DataLoader(
            train_set_source, batch_size=args.num_base_examples, shuffle=True)
    elif args.source == "fake":
        # there is also an option to use random noise base examples
        if args.target == "mnist":
            num_channels = 1
            dims = 28
        else:
            num_channels = 3
            dims = 32
        train_set_source = datasets.FakeData(size=5000,
                                             image_size=(num_channels, dims,
                                                         dims),
                                             num_classes=10,
                                             transform=transform_train,
                                             target_transform=None,
                                             random_offset=0)
        train_loader_source = torch.utils.data.DataLoader(
            train_set_source, batch_size=args.num_base_examples, shuffle=True)
    else:
        # get the fixed images from the same dataset as the training data
        train_set_source = train_set
        train_loader_source = torch.utils.data.DataLoader(
            train_set_source, batch_size=args.num_base_examples, shuffle=True)

    if torch.cuda.is_available():  # checks whether a cuda gpu is available
        device = torch.cuda.current_device()

        print("use GPU", device)
        print("GPU ID {}".format(torch.cuda.current_device()))
    else:
        print("use CPU")
        device = torch.device('cpu')  # sets the device to be CPU

    train_loader_source_iter = iter(train_loader_source)

    if args.balanced_source:
        # use a balanced set of fixed examples - same number of examples per class
        class_counts = {}
        fixed_input = []
        fixed_target = []

        for batch_fixed_i, batch_fixed_t in train_loader_source_iter:
            if sum(class_counts.values()) >= args.num_base_examples:
                break
            for fixed_i, fixed_t in zip(batch_fixed_i, batch_fixed_t):
                if len(class_counts.keys()) < num_classes:
                    if int(fixed_t) in class_counts:
                        if class_counts[
                                int(fixed_t
                                    )] < args.num_base_examples // num_classes:
                            class_counts[int(fixed_t)] += 1
                            fixed_input.append(fixed_i)
                            fixed_target.append(int(fixed_t))
                    else:
                        class_counts[int(fixed_t)] = 1
                        fixed_input.append(fixed_i)
                        fixed_target.append(int(fixed_t))
                else:
                    if int(fixed_t) in class_counts:
                        if class_counts[
                                int(fixed_t
                                    )] < args.num_base_examples // num_classes:
                            class_counts[int(fixed_t)] += 1
                            fixed_input.append(fixed_i)
                            fixed_target.append(int(fixed_t))
        fixed_input = torch.stack(fixed_input).to(device=device)
        fixed_target = torch.Tensor(fixed_target).to(device=device)
    else:
        # used for the cross-dataset scenario - random selection of classes
        # without taking into account the original classes
        fixed_input, fixed_target = next(train_loader_source_iter)
        fixed_input = fixed_input.to(device=device)
        fixed_target = fixed_target.to(device=device)

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().to(device=device)

    # start at uniform labels and then learn them
    labels = torch.full((args.num_base_examples, num_classes),
                        1.0 / num_classes,
                        requires_grad=True,
                        device=device)
    # define an optimizer for labels
    labels_opt = torch.optim.Adam([labels])

    # enable the meta-architectures used for second-order meta-learning;
    # these allow assigning fast weights
    cudnn.benchmark = True
    M.LeNetMeta.meta = True
    M.AlexCifarNetMeta.meta = True
    M.BasicBlockMeta.meta = True
    M.BottleneckMeta.meta = True
    M.ResNetMeta.meta = True

    # define the models to use
    if args.target == "cifar10" or args.target == "cifar100":
        if args.resnet:
            model = M.ResNetMeta(dataset=args.target,
                                 depth=18,
                                 num_classes=num_classes,
                                 bottleneck=False,
                                 device=device).to(device=device)
            model_name = 'resnet'
        else:
            model = M.AlexCifarNetMeta(args).to(device=device)
            model_name = 'alexnet'
    else:
        model = M.LeNetMeta(args).to(device=device)
        model_name = 'LeNet'
    optimizer = torch.optim.Adam(model.parameters())

    if args.baseline:
        create_json_experiment_log(fixed_target)
        # remap the targets - only relevant in cross-dataset
        fixed_target = remap_targets(fixed_target, num_classes)
        # printing the labels helps ensure the seeds work
        print('The labels of the fixed examples are')
        print(fixed_target.tolist())
        labels = one_hot(fixed_target.long(), num_classes)

        # add smoothing to the baseline if selected
        if args.label_smoothing > 0:
            labels = create_smooth_labels(labels, args.label_smoothing,
                                          num_classes)

        # use the validation set to find a suitable number of iterations for training
        num_baseline_steps, errors_list, num_steps_used = find_best_num_steps(
            val_loader, criterion, fixed_input, labels)
        print('Number of steps to use for the baseline: ' +
              str(num_baseline_steps))
        experiment_update_dict = {
            'num_baseline_steps': num_baseline_steps,
            'errors_list': errors_list,
            'num_steps_used': num_steps_used
        }
        update_json_experiment_log_dict(experiment_update_dict)

        if args.test_various_models:
            assert args.target == "cifar10", "test various models is only meant to be used for CIFAR-10"
            model_name_list = ['alexnet', 'LeNet', 'resnet']

            for model_name_test in model_name_list:
                # do 20 repetitions of training from scratch
                for test_i in range(20):
                    print('Test repetition ' + str(test_i))
                    test_err1, test_loss = test(test_loader, model_name_test,
                                                criterion, fixed_input, labels,
                                                num_baseline_steps)
                    print('Test error (top-1 error):', test_err1)
                    experiment_update_dict = {
                        'test_top_1_error_' + model_name_test: test_err1,
                        'test_loss_' + model_name_test: test_loss,
                        'num_test_steps_' + model_name_test: num_baseline_steps
                    }
                    update_json_experiment_log_dict(experiment_update_dict)
        else:
            # do 20 repetitions of training from scratch
            for test_i in range(20):
                print('Test repetition ' + str(test_i))
                test_err1, test_loss = test(test_loader, model_name, criterion,
                                            fixed_input, labels,
                                            num_baseline_steps)
                print('Test error (top-1 error):', test_err1)
                experiment_update_dict = {
                    'test_top_1_error': test_err1,
                    'test_loss': test_loss,
                    'num_test_steps': num_baseline_steps
                }
                update_json_experiment_log_dict(experiment_update_dict)

    else:
        create_json_experiment_log(fixed_target)

        # start measuring time
        start_time = time.time()

        # initialize variables to decide when to restart a model
        ma_list = []
        ma_sum = 0
        lowest_ma_sum = 999999999
        current_num_steps = 0
        num_steps_list = []
        num_steps_from_min = 0

        val_err1 = 100.0
        val_loss = 5.0
        num_steps_val = 0

        with tqdm.tqdm(total=args.epochs) as pbar_epochs:
            for epoch in range(0, args.epochs):
                train_err1, train_loss, labels, model_loss, ma_list, ma_sum, lowest_ma_sum, current_num_steps, num_steps_list, num_steps_from_min, model, optimizer = \
                    train(train_loader, model, fixed_input, labels, criterion, labels_opt, epoch, optimizer,
                          ma_list, ma_sum, lowest_ma_sum, current_num_steps, num_steps_list, num_steps_from_min)
                # evaluate on the validation set only every 5 epochs as it can be quite expensive to train a new model from scratch
                if epoch % 5 == 4:
                    # calculate the number of steps to use
                    if len(num_steps_list) == 0:
                        num_steps_val = current_num_steps
                    else:
                        num_steps_val = int(np.mean(num_steps_list[-3:]))

                    val_err1, val_loss = validate(val_loader, model, criterion,
                                                  epoch, fixed_input, labels,
                                                  num_steps_val)

                    if val_err1 <= best_err1:
                        best_labels = labels.detach().clone()
                        best_num_steps = num_steps_val
                        best_err1 = min(val_err1, best_err1)

                    print('Current best val error (top-1 error):', best_err1)

                pbar_epochs.update(1)

                experiment_update_dict = {
                    'train_top_1_error': train_err1,
                    'train_loss': train_loss,
                    'val_top_1_error': val_err1,
                    'val_loss': val_loss,
                    'model_loss': model_loss,
                    'epoch': epoch,
                    'num_val_steps': num_steps_val
                }
                # save the best labels so that we can analyse them
                if epoch == args.epochs - 1:
                    experiment_update_dict['labels'] = best_labels.tolist()

                update_json_experiment_log_dict(experiment_update_dict)

        print('Best val error (top-1 error):', best_err1)

        # stop measuring time
        experiment_update_dict = {'total_train_time': time.time() - start_time}
        update_json_experiment_log_dict(experiment_update_dict)

        # this does number of steps analysis - what happens if we do more or fewer steps for test training
        if args.num_steps_analysis:
            num_steps_add = [-50, -20, -10, 0, 10, 20, 50, 100]

            for num_steps_add_item in num_steps_add:
                # start measuring time for testing
                start_time = time.time()
                local_errs = []
                local_losses = []
                local_num_steps = best_num_steps + num_steps_add_item
                print('Number of steps for training: ' + str(local_num_steps))
                # each number of steps will have a robust estimate by using 20 repetitions
                for test_i in range(20):
                    print('Test repetition ' + str(test_i))
                    test_err1, test_loss = test(test_loader, model_name,
                                                criterion, fixed_input,
                                                best_labels, local_num_steps)
                    local_errs.append(test_err1)
                    local_losses.append(test_loss)
                    print('Test error (top-1 error):', test_err1)
                experiment_update_dict = {
                    'test_top_1_error': local_errs,
                    'test_loss': local_losses,
                    'total_test_time': time.time() - start_time,
                    'num_test_steps': local_num_steps
                }
                update_json_experiment_log_dict(experiment_update_dict)
        else:
            # restart the timer so that the logged test time excludes training
            start_time = time.time()
            if args.test_various_models:
                assert args.target == "cifar10", "test various models is only meant to be used for CIFAR-10"
                model_name_list = ['alexnet', 'LeNet', 'resnet']

                for model_name_test in model_name_list:
                    for test_i in range(20):
                        print(model_name_test)
                        print('Test repetition ' + str(test_i))
                        test_err1, test_loss = test(test_loader,
                                                    model_name_test, criterion,
                                                    fixed_input, best_labels,
                                                    best_num_steps)
                        print('Test error (top-1 error):', test_err1)
                        experiment_update_dict = {
                            'test_top_1_error_' + model_name_test: test_err1,
                            'test_loss_' + model_name_test: test_loss,
                            'total_test_time_' + model_name_test:
                            time.time() - start_time,
                            'num_test_steps_' + model_name_test: best_num_steps
                        }
                        update_json_experiment_log_dict(experiment_update_dict)
            else:
                for test_i in range(20):
                    print('Test repetition ' + str(test_i))
                    test_err1, test_loss = test(test_loader, model_name,
                                                criterion, fixed_input,
                                                best_labels, best_num_steps)
                    print('Test error (top-1 error):', test_err1)
                    experiment_update_dict = {
                        'test_top_1_error': test_err1,
                        'test_loss': test_loss,
                        'total_test_time': time.time() - start_time,
                        'num_test_steps': best_num_steps
                    }
                    update_json_experiment_log_dict(experiment_update_dict)
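
Finally, K49Dataset, remap_targets and create_smooth_labels are used above but defined elsewhere in the repository. A minimal sketch of plausible implementations follows; the details (the PIL conversion, the in-order class remapping) are assumptions.

import torch
from PIL import Image
from torch.utils.data import Dataset


class K49Dataset(Dataset):
    """Wrap the Kuzushiji-49 numpy arrays as a torch Dataset."""

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # convert the 28x28 uint8 array to PIL so standard transforms apply
        image = Image.fromarray(self.images[index])
        if self.transform is not None:
            image = self.transform(image)
        return image, int(self.labels[index])


def remap_targets(targets, num_classes):
    """Map arbitrary source-class ids onto [0, num_classes) in order of appearance."""
    mapping = {}
    remapped = []
    for t in targets.tolist():
        if t not in mapping:
            mapping[t] = len(mapping) % num_classes
        remapped.append(mapping[t])
    return torch.tensor(remapped, device=targets.device)


def create_smooth_labels(labels, label_smoothing, num_classes):
    """Apply standard label smoothing to one-hot label vectors."""
    return labels * (1.0 - label_smoothing) + label_smoothing / num_classes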