def validate(val_loader, model, criterion, epoch, distill_data,
             distill_labels, num_steps, normalize):
    """Validation phase. This involves training a model from scratch.

    A fresh model (AlexNet-style for CIFAR-10, LeNet otherwise) is trained
    for ``num_steps`` optimizer steps using only the synthetic images and
    their soft labels, and is then evaluated on the validation set.

    Returns:
        Tuple of (average top-1 error, average validation loss).
    """
    losses_eval = AverageMeter()
    top1_eval = AverageMeter()
    start = time.time()
    print('Number of steps for retraining during validation: ' +
          str(num_steps))
    # initialize a new model
    if args.target == "cifar10":
        model_eval = M.AlexCifarNetMeta(args).to(device=device)
    else:
        model_eval = M.LeNetMeta(args).to(device=device)
    optimizer = torch.optim.Adam(model_eval.parameters())
    model_eval.train()
    # train a model from scratch for the given number of steps
    # using only the synthetic images and their labels
    for ITER in range(num_steps):
        # sample a minibatch of n_i examples from x~, y~: x~', y~'
        perm = torch.randperm(distill_data.size(0))
        idx = perm[:args.inner_batch_size]
        # synthetic images are stored unnormalized in [0, 1], so apply the
        # dataset normalization before feeding them to the model; detach so
        # this training does not backprop into the synthetic images
        fi_i = torch.stack([
            normalize(image) for image in distill_data[idx].detach().cpu()
        ]).to(device=device)
        lb_i = distill_labels[idx]
        fi_o = model_eval(fi_i)
        loss = soft_cross_entropy(fi_o, lb_i)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # evaluate the validation error
    # switch to evaluate mode
    model_eval.eval()
    # no_grad: evaluation needs no autograd graph, which saves memory/time
    with torch.no_grad():
        for i, (input_, target) in enumerate(val_loader):
            input_ = input_.to(device=device)
            target = target.to(device=device)
            output = model_eval(input_)
            loss = criterion(output, target)
            # measure error rate and record loss
            err1 = compute_error_rate(output.data, target, topk=(1, ))[0]
            top1_eval.update(err1.item(), input_.size(0))
            losses_eval.update(loss.item(), input_.size(0))
    val_time = time.time() - start
    print(
        '* Epoch: [{0}/{1}]\t Top 1-err {top1.avg:.3f} Val Loss {loss.avg:.3f} Time {val_time:.3f}'
        .format(epoch,
                args.epochs,
                top1=top1_eval,
                loss=losses_eval,
                val_time=val_time))
    return top1_eval.avg, losses_eval.avg
def train(train_loader, model, distill_data, distill_labels, criterion,
          data_opt, epoch, optimizer, ma_list, ma_sum, lowest_ma_sum,
          current_num_steps, num_steps_list, num_steps_from_min, normalize):
    """Do one epoch of training the synthetic images (dataset distillation).

    Inner loop: compute "fast weights" from one gradient step on the
    synthetic minibatch (second-order, via create_graph=True).
    Outer loop: backprop the target-data loss through those fast weights
    into the synthetic images via ``data_opt``.

    Parameters ma_list, ma_sum, lowest_ma_sum, current_num_steps,
    num_steps_list, num_steps_from_min carry moving-average statistics for
    model resets across epochs; they are returned (possibly reset) together
    with the metrics, the updated synthetic data, and the (possibly
    restarted) model and optimizer.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    model_losses = AverageMeter()
    top1 = AverageMeter()
    # switch to train mode
    model.train()
    end = time.time()
    for i, (input_, target) in enumerate(train_loader):
        # sample a minibatch of n_o target dataset examples x_t' with labels y_t'
        # measure data loading time
        data_time.update(time.time() - end)
        input_ = input_.to(device=device)
        target = target.to(device=device)
        # sample a minibatch of n_i examples from x~, y~: x~', y~'
        perm = torch.randperm(distill_data.size(0))
        idx = perm[:args.inner_batch_size]
        # normalize the synthetic images using the standard normalization
        # NOTE: no detach() here — gradients must flow into distill_data
        fi_i = torch.stack([
            normalize(image) for image in distill_data[idx].cpu()
        ]).to(device=device)
        lb_i = distill_labels[idx]
        # inner loop: clear any stale fast weights before the forward pass
        for weight in model.parameters():
            weight.fast = None
        fi_o = model(fi_i)
        loss = soft_cross_entropy(fi_o, lb_i)
        optimizer.zero_grad()
        grad = torch.autograd.grad(loss, model.parameters(),
                                   create_graph=True)
        # create fast weights so that we can use second-order gradient
        for k, weight in enumerate(model.parameters()):
            weight.fast = weight - args.inner_lr * grad[k]
        # outer loop
        # update x~ <-- x~ - beta nabla_x~ L(f_w(f_theta(x_t)), y_t)
        # fast weights will be used
        logit = model(input_)
        data_loss = criterion(logit, target)
        data_opt.zero_grad()
        data_loss.backward(retain_graph=False)
        data_opt.step()
        # make sure the synthetic examples have valid values after the update
        distill_data.data = torch.clamp(distill_data.data, 0, 1)
        # calculate the loss again for updating the features
        # (the synthetic data just changed, so the old graph is stale)
        fi_i = torch.stack([
            normalize(image) for image in distill_data[idx].cpu()
        ]).to(device=device)
        lb_i = distill_labels[idx]
        for weight in model.parameters():
            weight.fast = None
        fi_o = model(fi_i)
        loss = soft_cross_entropy(fi_o, lb_i)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        model_losses.update(loss.item(), input_.size(0))
        # measure error rate and record loss
        err1 = compute_error_rate(logit.data, target,
                                  topk=(1, ))[0]  # it returns a list
        losses.update(data_loss.item(), input_.size(0))
        top1.update(err1.item(), input_.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # update the moving average statistics
        # for image distillation, the values are fixed to 50
        if len(ma_list) < 50:
            # still filling the moving-average window
            ma_list.append(err1.item())
            ma_sum += err1.item()
            current_num_steps += 1
            current_ma = ma_sum / len(ma_list)
            if current_num_steps == 50:
                lowest_ma_sum = ma_sum
                num_steps_from_min = 0
        else:
            # slide the window by one element
            ma_sum = ma_sum - ma_list[0] + err1.item()
            ma_list = ma_list[1:] + [err1.item()]
            current_num_steps += 1
            current_ma = ma_sum / len(ma_list)
            if ma_sum < lowest_ma_sum:
                lowest_ma_sum = ma_sum
                num_steps_from_min = 0
            elif num_steps_from_min < 50:
                num_steps_from_min += 1
            else:
                # do early stopping: no improvement for 50 steps
                num_steps_list.append(current_num_steps - num_steps_from_min -
                                      1)
                # restart all metrics
                ma_list = []
                ma_sum = 0
                lowest_ma_sum = 999999999
                current_num_steps = 0
                num_steps_from_min = 0
                # restart the model and the optimizer
                if args.target == "cifar10":
                    model = M.AlexCifarNetMeta(args).to(device=device)
                else:
                    model = M.LeNetMeta(args).to(device=device)
                optimizer = torch.optim.Adam(model.parameters())
                print('Model restarted after ' + str(num_steps_list[-1]) +
                      ' steps')
        if i % args.print_freq == 0 and args.verbose is True:
            print('Epoch: [{0}/{1}][{2}/{3}]\t'
                  'LR: {LR:.6f}\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\t'
                  'MAvg 1-err {current_ma:.4f}'.format(epoch,
                                                       args.epochs,
                                                       i,
                                                       len(train_loader),
                                                       LR=1,
                                                       batch_time=batch_time,
                                                       data_time=data_time,
                                                       loss=losses,
                                                       top1=top1,
                                                       current_ma=current_ma))
    print(
        '* Epoch: [{0}/{1}]\t Top 1-err {top1.avg:.3f} Train Loss {loss.avg:.3f}'
        .format(epoch, args.epochs, top1=top1, loss=losses))
    return top1.avg, losses.avg, distill_data, model_losses.avg, \
        ma_list, ma_sum, lowest_ma_sum, current_num_steps, num_steps_list, num_steps_from_min, \
        model, optimizer
def main():
    """Entry point: distill a target dataset into synthetic images.

    Builds the target dataset (MNIST or CIFAR-10), learns the synthetic
    images epoch by epoch, periodically validates by retraining a model
    from scratch, then evaluates on the test set (optionally sweeping the
    number of retraining steps).
    """
    global args, best_err1, device, num_classes
    args = get_args()
    torch.manual_seed(args.random_seed)
    best_err1 = 100
    # define datasets
    if args.target == "mnist":
        normalize = transforms.Normalize((0.1307, ), (0.3081, ))
        transform_train = transforms.Compose(
            [transforms.ToTensor(), normalize])
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])
        train_set_all = datasets.MNIST('data',
                                       train=True,
                                       transform=transform_train,
                                       target_transform=None,
                                       download=True)
        # set aside 10000 examples from the training set for validation
        train_set, val_set = torch.utils.data.random_split(
            train_set_all, [50000, 10000])
        # if we do experiments with variable target set size, this will take care of it
        target_set_size = min(50000, args.target_set_size)
        train_set, _ = torch.utils.data.random_split(
            train_set, [target_set_size, 50000 - target_set_size])
        test_set = datasets.MNIST('data',
                                  train=False,
                                  transform=transform_test,
                                  target_transform=None,
                                  download=True)
        num_classes = 10
        num_channels = 1
        input_size = 28
    elif args.target == "cifar10":
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
        transform_train = transforms.Compose(
            [transforms.Resize((32, 32)),
             transforms.ToTensor(), normalize])
        transform_test = transforms.Compose(
            [transforms.Resize((32, 32)),
             transforms.ToTensor(), normalize])
        train_set_all = datasets.CIFAR10('data',
                                         train=True,
                                         transform=transform_train,
                                         target_transform=None,
                                         download=True)
        # set aside 5000 examples from the training set for validation
        train_set, val_set = torch.utils.data.random_split(
            train_set_all, [45000, 5000])
        # if we do experiments with variable target set size, this will take care of it
        target_set_size = min(45000, args.target_set_size)
        train_set, _ = torch.utils.data.random_split(
            train_set, [target_set_size, 45000 - target_set_size])
        test_set = datasets.CIFAR10('data',
                                    train=False,
                                    transform=transform_test,
                                    target_transform=None,
                                    download=True)
        num_classes = 10
        num_channels = 3
        input_size = 32
    else:
        # fix: raising a plain string is a TypeError in Python 3;
        # raise a proper exception instead
        raise ValueError("The dataset is not currently supported")
    # create data loaders
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              shuffle=False)
    if torch.cuda.is_available():
        # checks whether a cuda gpu is available
        device = torch.cuda.current_device()
        print("use GPU", device)
        print("GPU ID {}".format(torch.cuda.current_device()))
    else:
        print("use CPU")
        device = torch.device('cpu')  # sets the device to be CPU
    # randomly initialize the images and create associated
    # labels [0, 1, 2, ..., 0, 1, 2, ...]
    distill_labels = torch.arange(num_classes, dtype=torch.long, device=device) \
        .repeat(args.num_base_examples // num_classes, 1).reshape(-1)
    distill_labels = one_hot(distill_labels, num_classes)
    distill_data = torch.rand(args.num_base_examples,
                              num_channels,
                              input_size,
                              input_size,
                              device=device,
                              requires_grad=True)
    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().to(device=device)
    data_opt = torch.optim.Adam([distill_data])
    cudnn.benchmark = True
    # enable fast-weight support for second-order meta-learning
    M.LeNetMeta.meta = True
    M.AlexCifarNetMeta.meta = True
    # define the models to use
    if args.target == "cifar10":
        model = M.AlexCifarNetMeta(args).to(device=device)
    else:
        model = M.LeNetMeta(args).to(device=device)
    optimizer = torch.optim.Adam(model.parameters())
    create_json_experiment_log()
    # start measuring time
    start_time = time.time()
    # initialize early stopping variables
    ma_list = []
    ma_sum = 0
    lowest_ma_sum = 999999999
    current_num_steps = 0
    num_steps_list = []
    num_steps_from_min = 0
    val_err1 = 100.0
    val_loss = 5.0
    num_steps_val = 0
    with tqdm.tqdm(total=args.epochs) as pbar_epochs:
        for epoch in range(0, args.epochs):
            train_err1, train_loss, distill_data, model_loss, ma_list, ma_sum, lowest_ma_sum, current_num_steps, num_steps_list, num_steps_from_min, model, optimizer = \
                train(train_loader, model, distill_data, distill_labels,
                      criterion, data_opt, epoch, optimizer, ma_list, ma_sum,
                      lowest_ma_sum, current_num_steps, num_steps_list,
                      num_steps_from_min, normalize)
            # evaluate on the validation set only every 5 epochs as it can be quite expensive to train a new model from scratch
            if epoch % 5 == 4:
                # calculate the number of steps to use
                if len(num_steps_list) == 0:
                    num_steps_val = current_num_steps
                else:
                    num_steps_val = int(np.mean(num_steps_list[-3:]))
                val_err1, val_loss = validate(val_loader, model, criterion,
                                              epoch, distill_data,
                                              distill_labels, num_steps_val,
                                              normalize)
            # otherwise the stats keep the previous value
            # NOTE(review): if args.epochs < 5 no validation ever runs, but
            # val_err1 starts at 100 <= best_err1, so best_distill_data is
            # still assigned on the first epoch
            if val_err1 <= best_err1:
                best_distill_data = distill_data.detach().clone()
                best_num_steps = num_steps_val
            best_err1 = min(val_err1, best_err1)
            print('Current best val error (top-1 error):', best_err1)
            pbar_epochs.update(1)
            experiment_update_dict = {
                'train_top_1_error': train_err1,
                'train_loss': train_loss,
                'val_top_1_error': val_err1,
                'val_loss': val_loss,
                'model_loss': model_loss,
                'epoch': epoch,
                'num_val_steps': num_steps_val
            }
            # save the best images so that we can analyse them
            if epoch == args.epochs - 1:
                experiment_update_dict['data'] = best_distill_data.tolist()
            update_json_experiment_log_dict(experiment_update_dict)
    print('Best val error (top-1 error):', best_err1)
    # stop measuring time
    experiment_update_dict = {'total_train_time': time.time() - start_time}
    update_json_experiment_log_dict(experiment_update_dict)
    # this does number of steps analysis - what happens if we do more or fewer steps for training
    if args.num_steps_analysis:
        num_steps_add = [-50, -20, -10, 0, 10, 20, 50, 100]
        for num_steps_add_item in num_steps_add:
            # start measuring time for testing
            start_time = time.time()
            local_errs = []
            local_losses = []
            local_num_steps = best_num_steps + num_steps_add_item
            print('Number of steps for training: ' + str(local_num_steps))
            # each number of steps will have a robust estimate by using 20 repetitions
            for test_i in range(20):
                print('Test repetition ' + str(test_i))
                test_err1, test_loss = test(test_loader, model, criterion,
                                            best_distill_data, distill_labels,
                                            local_num_steps, normalize)
                local_errs.append(test_err1)
                local_losses.append(test_loss)
                print('Test error (top-1 error):', test_err1)
            experiment_update_dict = {
                'test_top_1_error': local_errs,
                'test_loss': local_losses,
                'total_test_time': time.time() - start_time,
                'num_test_steps': local_num_steps
            }
            update_json_experiment_log_dict(experiment_update_dict)
    else:
        # evaluate on test set repeatedly for a robust estimate
        for test_i in range(20):
            print('Test repetition ' + str(test_i))
            test_err1, test_loss = test(test_loader, model, criterion,
                                        best_distill_data, distill_labels,
                                        best_num_steps, normalize)
            print('Test error (top-1 error):', test_err1)
            experiment_update_dict = {
                'test_top_1_error': test_err1,
                'test_loss': test_loss,
                'total_test_time': time.time() - start_time,
                'num_test_steps': best_num_steps
            }
            update_json_experiment_log_dict(experiment_update_dict)
def find_best_num_steps(val_loader, criterion, fixed_input, labels):
    """Calculate the best number of steps to use based on the validation set.

    Trains a fresh model on the base examples and periodically measures
    validation error, tracking which step count achieves the lowest error.

    Returns:
        Tuple of (best number of steps, list of validation errors,
        list of step counts at which evaluation happened).
    """
    best_num_steps = 0
    lowest_err = 100
    errors_list = []
    num_steps_used = []
    # use a larger number of max steps when using more base examples
    if args.num_base_examples > 100:
        max_num_steps = 1701
    else:
        max_num_steps = 1000
    # initialize a new model
    if args.target == "cifar10" or args.target == "cifar100":
        if args.resnet:
            model_eval = M.ResNetMeta(dataset=args.target,
                                      depth=18,
                                      num_classes=num_classes,
                                      bottleneck=False,
                                      device=device).to(device=device)
        else:
            model_eval = M.AlexCifarNetMeta(args).to(device=device)
    else:
        model_eval = M.LeNetMeta(args).to(device=device)
    optimizer = torch.optim.Adam(model_eval.parameters())
    model_eval.train()
    for ITER in range(max_num_steps):
        # back to train mode (evaluation below switches to eval mode)
        model_eval.train()
        # sample a minibatch of n_i examples from x~, y~: x~', y~'
        perm = torch.randperm(fixed_input.size(0))
        idx = perm[:args.inner_batch_size]
        fi_i = fixed_input[idx]
        lb_i = labels[idx].detach()
        fi_o = model_eval(fi_i)
        loss = soft_cross_entropy(fi_o, lb_i)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # for larger numbers of base examples we decrease the frequency of evaluation
        if args.num_base_examples > 100:
            if ITER in set(
                [9, 24, 74, 100, 200, 300, 500, 700, 1000, 1200, 1500, 1700]):
                validate_now = True
            else:
                validate_now = False
        else:
            if ITER % 50 == 49 or ITER in set([9, 24, 74]):
                validate_now = True
            else:
                validate_now = False
        if validate_now:
            # switch to evaluate mode
            model_eval.eval()
            top1_eval = AverageMeter()
            losses_eval = AverageMeter()
            # no_grad: evaluation needs no autograd graph (saves memory)
            with torch.no_grad():
                for i, (input_, target) in enumerate(val_loader):
                    input_ = input_.to(device=device)
                    target = target.to(device=device)
                    output = model_eval(input_)
                    loss = criterion(output, target)
                    # measure error and record loss
                    err1 = compute_error_rate(output.data, target,
                                              topk=(1, ))[0]
                    top1_eval.update(err1.item(), input_.size(0))
                    losses_eval.update(loss.item(), input_.size(0))
            errors_list.append(top1_eval.avg)
            num_steps_used.append(ITER + 1)
            if top1_eval.avg < lowest_err:
                lowest_err = top1_eval.avg
                best_num_steps = ITER + 1
    print(num_steps_used)
    print(errors_list)
    return best_num_steps, errors_list, num_steps_used
def test(test_loader, model_name, criterion, fixed_input, labels, num_steps):
    """Test phase. This involves training a model from scratch.

    A fresh model (selected by ``model_name`` when testing various models,
    otherwise by the args configuration) is trained for ``num_steps`` steps
    on the base examples with their synthetic labels, then evaluated on the
    test set.

    Returns:
        Tuple of (average top-1 error, average test loss).
    """
    losses_eval = AverageMeter()
    top1_eval = AverageMeter()
    start = time.time()
    print('Number of steps for retraining during test: ' + str(num_steps))
    # initialize a new model
    if args.target == "cifar10" or args.target == "cifar100":
        if args.test_various_models:
            if model_name == 'resnet':
                model_eval = M.ResNetMeta(dataset=args.target,
                                          depth=18,
                                          num_classes=num_classes,
                                          bottleneck=False,
                                          device=device).to(device=device)
            elif model_name == 'LeNet':
                model_eval = M.LeNetMeta(args).to(device=device)
            else:
                model_eval = M.AlexCifarNetMeta(args).to(device=device)
        else:
            if args.resnet:
                model_eval = M.ResNetMeta(dataset=args.target,
                                          depth=18,
                                          num_classes=num_classes,
                                          bottleneck=False,
                                          device=device).to(device=device)
            else:
                model_eval = M.AlexCifarNetMeta(args).to(device=device)
    else:
        model_eval = M.LeNetMeta(args).to(device=device)
    optimizer = torch.optim.Adam(model_eval.parameters())
    model_eval.train()
    # train a model from scratch for the given number of steps
    # using only the base examples and their synthetic labels
    for ITER in range(num_steps):
        # sample a minibatch of n_i examples from x~, y~: x~', y~'
        perm = torch.randperm(fixed_input.size(0))
        idx = perm[:args.inner_batch_size]
        fi_i = fixed_input[idx]
        lb_i = labels[idx].detach()
        fi_o = model_eval(fi_i)
        loss = soft_cross_entropy(fi_o, lb_i)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # evaluate the test error
    # switch to evaluate mode
    model_eval.eval()
    # no_grad: evaluation needs no autograd graph (saves memory/time)
    with torch.no_grad():
        for i, (input_, target) in enumerate(test_loader):
            input_ = input_.to(device=device)
            target = target.to(device=device)
            output = model_eval(input_)
            loss = criterion(output, target)
            # measure error and record loss
            err1 = compute_error_rate(output.data, target, topk=(1, ))[0]
            top1_eval.update(err1.item(), input_.size(0))
            losses_eval.update(loss.item(), input_.size(0))
    test_time = time.time() - start
    print('Testing with a model trained from scratch')
    print('Test time: ' + str(test_time))
    print('Test error (top-1 error): {top1_eval.avg:.4f}'.format(
        top1_eval=top1_eval))
    return top1_eval.avg, losses_eval.avg
def train(train_loader, model, fixed_input, labels, criterion, labels_opt,
          epoch, optimizer, ma_list, ma_sum, lowest_ma_sum, current_num_steps,
          num_steps_list, num_steps_from_min):
    """Do one epoch of training the synthetic labels (label distillation).

    Inner loop: compute "fast weights" from one gradient step on the base
    examples with their current synthetic labels (second-order, via
    create_graph=True).
    Outer loop: backprop the target-data loss through the fast weights into
    the synthetic labels via ``labels_opt``, then update the model itself.

    Parameters ma_list, ma_sum, lowest_ma_sum, current_num_steps,
    num_steps_list, num_steps_from_min are used for keeping track of
    statistics for model resets across epochs; they are returned together
    with the metrics, the updated labels, and the (possibly restarted)
    model and optimizer.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    model_losses = AverageMeter()
    top1 = AverageMeter()
    # switch to train mode
    model.train()
    end = time.time()
    # define over how many steps to calculate the moving average
    # for resetting the model and also how many steps to wait
    # we use the same value for both
    if args.target == "cifar100" or args.target == "k49":
        stats_gap = 100
    elif args.num_base_examples > 100:
        stats_gap = 200
    else:
        stats_gap = 50
    for i, (input_, target) in enumerate(train_loader):
        # sampled a minibatch of n_o target dataset examples x_t' with labels y_t'
        # measure data loading time
        data_time.update(time.time() - end)
        input_ = input_.to(device=device)
        target = target.to(device=device)
        # sample a minibatch of n_i base examples from x~, y~: x~', y~'
        perm = torch.randperm(fixed_input.size(0))
        idx = perm[:args.inner_batch_size]
        fi_i = fixed_input[idx]
        lb_i = labels[idx]
        # inner loop: clear any stale fast weights before the forward pass
        for weight in model.parameters():
            weight.fast = None
        fi_o = model(fi_i)
        loss = soft_cross_entropy(fi_o, lb_i)
        optimizer.zero_grad()
        grad = torch.autograd.grad(loss, model.parameters(),
                                   create_graph=True)
        # create fast weights so that we can use second-order gradient
        for k, weight in enumerate(model.parameters()):
            weight.fast = weight - args.inner_lr * grad[k]
        # outer loop
        # update y~ <-- y~ - beta nabla_y~ L(f_theta(x_t'), y_t')
        # fast weights will be used
        logit = model(input_)
        label_loss = criterion(logit, target)
        labels_opt.zero_grad()
        # retain_graph=True allows us to update the feature extractor
        # without calculating the loss again
        label_loss.backward(retain_graph=True)
        labels_opt.step()
        # normalize the labels to form a valid probability distribution
        labels.data = torch.clamp(labels.data, 0, 1)
        labels.data = labels.data / labels.data.sum(dim=1).unsqueeze(1)
        # now update the model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        model_losses.update(loss.item(), input_.size(0))
        # measure error and record loss
        err1 = compute_error_rate(logit.data, target,
                                  topk=(1, ))[0]  # it returns a list
        losses.update(label_loss.item(), input_.size(0))
        top1.update(err1.item(), input_.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # update the moving average statistics
        if len(ma_list) < stats_gap:
            # still filling the moving-average window
            ma_list.append(err1.item())
            ma_sum += err1.item()
            current_num_steps += 1
            current_ma = ma_sum / len(ma_list)
            if current_num_steps == stats_gap:
                lowest_ma_sum = ma_sum
                num_steps_from_min = 0
        else:
            # slide the window by one element
            ma_sum = ma_sum - ma_list[0] + err1.item()
            ma_list = ma_list[1:] + [err1.item()]
            current_num_steps += 1
            current_ma = ma_sum / len(ma_list)
            if ma_sum < lowest_ma_sum:
                lowest_ma_sum = ma_sum
                num_steps_from_min = 0
            elif num_steps_from_min < stats_gap:
                num_steps_from_min += 1
            else:
                # do early stopping: no improvement for stats_gap steps
                num_steps_list.append(current_num_steps - num_steps_from_min -
                                      1)
                # restart all metrics
                ma_list = []
                ma_sum = 0
                lowest_ma_sum = 999999999
                current_num_steps = 0
                num_steps_from_min = 0
                # restart the model and the optimizer
                if args.target == "cifar10" or args.target == "cifar100":
                    if args.resnet:
                        model = M.ResNetMeta(dataset=args.target,
                                             depth=18,
                                             num_classes=num_classes,
                                             bottleneck=False,
                                             device=device).to(device=device)
                    else:
                        model = M.AlexCifarNetMeta(args).to(device=device)
                else:
                    model = M.LeNetMeta(args).to(device=device)
                optimizer = torch.optim.Adam(model.parameters())
                print('Model restarted after ' + str(num_steps_list[-1]) +
                      ' steps')
        if i % args.print_freq == 0 and args.verbose is True:
            print('Epoch: [{0}/{1}][{2}/{3}]\t'
                  'LR: {LR:.6f}\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\t'
                  'MAvg 1-err {current_ma:.4f}'.format(epoch,
                                                       args.epochs,
                                                       i,
                                                       len(train_loader),
                                                       LR=1,
                                                       batch_time=batch_time,
                                                       data_time=data_time,
                                                       loss=losses,
                                                       top1=top1,
                                                       current_ma=current_ma))
    print(
        '* Epoch: [{0}/{1}]\t Top 1-err {top1.avg:.3f} Train Loss {loss.avg:.3f}'
        .format(epoch, args.epochs, top1=top1, loss=losses))
    return top1.avg, losses.avg, labels, model_losses.avg, \
        ma_list, ma_sum, lowest_ma_sum, current_num_steps, num_steps_list, num_steps_from_min, \
        model, optimizer
def main(): global args, best_err1, device, num_classes args = get_args() torch.manual_seed(args.random_seed) # most cases have 10 classes # if there are more, then it will be reassigned num_classes = 10 best_err1 = 100 # define datasets if args.target == "mnist": transform_train = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, )) ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, )) ]) train_set_all = datasets.MNIST('data', train=True, transform=transform_train, target_transform=None, download=True) # set aside 10000 examples from the training set for validation train_set, val_set = torch.utils.data.random_split( train_set_all, [50000, 10000]) # if we do experiments with variable target set size, this will take care of it # by default the target set size is 50000 target_set_size = min(50000, args.target_set_size) train_set, _ = torch.utils.data.random_split( train_set, [target_set_size, 50000 - target_set_size]) test_set = datasets.MNIST('data', train=False, transform=transform_test, target_transform=None, download=True) elif args.target == "kmnist": transform_train = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, )) ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, )) ]) train_set_all = datasets.KMNIST('data', train=True, transform=transform_train, target_transform=None, download=True) # set aside 10000 examples from the training set for validation train_set, val_set = torch.utils.data.random_split( train_set_all, [50000, 10000]) target_set_size = min(50000, args.target_set_size) # if we do experiments with variable target set size, this will take care of it train_set, _ = torch.utils.data.random_split( train_set, [target_set_size, 50000 - target_set_size]) test_set = datasets.KMNIST('data', train=False, transform=transform_test, target_transform=None, download=True) elif 
args.target == "k49": num_classes = 49 train_images = np.load('./data/k49-train-imgs.npz')['arr_0'] test_images = np.load('./data/k49-test-imgs.npz')['arr_0'] train_labels = np.load('./data/k49-train-labels.npz')['arr_0'] test_labels = np.load('./data/k49-test-labels.npz')['arr_0'] transform_train = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, )) ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081, )) ]) # set aside about 10% of training data for validation train_set_all = K49Dataset(train_images, train_labels, transform=transform_train) train_set, val_set = torch.utils.data.random_split( train_set_all, [209128, 23237]) # currently we do not support variable target set size for k49 # enable this to use it # target_set_size = min(209128, args.target_set_size) # train_set, _ = torch.utils.data.random_split( # train_set, [target_set_size, 209128 - target_set_size]) test_set = K49Dataset(test_images, test_labels, transform=transform_test) elif args.target == "cifar10": normalize = transforms.Normalize( mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], std=[x / 255.0 for x in [63.0, 62.1, 66.7]]) transform_train = transforms.Compose( [transforms.Resize((32, 32)), transforms.ToTensor(), normalize]) transform_test = transforms.Compose( [transforms.Resize((32, 32)), transforms.ToTensor(), normalize]) train_set_all = datasets.CIFAR10('data', train=True, transform=transform_train, target_transform=None, download=True) # set aside 5000 examples from the training set for validation train_set, val_set = torch.utils.data.random_split( train_set_all, [45000, 5000]) # if we do experiments with variable target set size, this will take care of it target_set_size = min(45000, args.target_set_size) train_set, _ = torch.utils.data.random_split( train_set, [target_set_size, 45000 - target_set_size]) test_set = datasets.CIFAR10('data', train=False, transform=transform_test, target_transform=None, 
download=True) elif args.target == "cifar100": num_classes = 100 normalize = transforms.Normalize( mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], std=[x / 255.0 for x in [63.0, 62.1, 66.7]]) transform_train = transforms.Compose( [transforms.Resize((32, 32)), transforms.ToTensor(), normalize]) transform_test = transforms.Compose( [transforms.Resize((32, 32)), transforms.ToTensor(), normalize]) train_set_all = datasets.CIFAR100('data', train=True, transform=transform_train, target_transform=None, download=True) # set aside 5000 examples from the training set for validation train_set, val_set = torch.utils.data.random_split( train_set_all, [45000, 5000]) # if we do experiments with variable target set size, this will take care of it target_set_size = min(45000, args.target_set_size) train_set, _ = torch.utils.data.random_split( train_set, [target_set_size, 45000 - target_set_size]) test_set = datasets.CIFAR100('data', train=False, transform=transform_test, target_transform=None, download=True) # create data loaders if args.baseline: train_loader = torch.utils.data.DataLoader( train_set, batch_size=args.num_base_examples, shuffle=True) else: train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True) val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False) test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False) # create data loaders to get base examples if args.source == "emnist": train_set_source = datasets.EMNIST('data', 'letters', train=True, download=True, transform=transform_train, target_transform=None) train_loader_source = torch.utils.data.DataLoader( train_set_source, batch_size=args.num_base_examples, shuffle=True) elif args.source == "mnist": train_set_source = datasets.MNIST('data', train=True, download=True, transform=transform_train, target_transform=None) train_loader_source = torch.utils.data.DataLoader( train_set_source, 
# ---------------------------------------------------------------------------
# Source-dataset selection (continuation of the `if args.source == ...` chain
# that opens above this chunk; the first line below closes the DataLoader call
# of the preceding branch).  Each branch builds `train_set_source` and a
# DataLoader whose batch size equals `args.num_base_examples`, so one `next()`
# on the loader yields the whole pool of candidate base examples.
# ---------------------------------------------------------------------------
batch_size=args.num_base_examples, shuffle=True)
elif args.source == "kmnist":
    train_set_source = datasets.KMNIST('data', train=True, download=True,
                                       transform=transform_train,
                                       target_transform=None)
    train_loader_source = torch.utils.data.DataLoader(
        train_set_source, batch_size=args.num_base_examples, shuffle=True)
elif args.source == "cifar10":
    train_set_source = datasets.CIFAR10('data', train=True, download=True,
                                        transform=transform_train,
                                        target_transform=None)
    train_loader_source = torch.utils.data.DataLoader(
        train_set_source, batch_size=args.num_base_examples, shuffle=True)
elif args.source == "cifar100":
    train_set_source = datasets.CIFAR100('data', train=True, download=True,
                                         transform=transform_train,
                                         target_transform=None)
    train_loader_source = torch.utils.data.DataLoader(
        train_set_source, batch_size=args.num_base_examples, shuffle=True)
elif args.source == "svhn":
    # torchvision's SVHN uses `split=` rather than the `train=` flag
    train_set_source = datasets.SVHN('data', split='train', download=True,
                                     transform=transform_train,
                                     target_transform=None)
    train_loader_source = torch.utils.data.DataLoader(
        train_set_source, batch_size=args.num_base_examples, shuffle=True)
elif args.source == "cub":
    # modify the root depending on where you place the images
    cub_data_root = './data/CUB_200_2011/images'
    train_set_source = datasets.ImageFolder(cub_data_root,
                                            transform=transform_train,
                                            target_transform=None)
    train_loader_source = torch.utils.data.DataLoader(
        train_set_source, batch_size=args.num_base_examples, shuffle=True)
elif args.source == "fake":
    # there is also an option to use random noise base examples
    if args.target == "mnist":
        num_channels = 1
        dims = 28
    else:
        num_channels = 3
        dims = 32
    train_set_source = datasets.FakeData(size=5000,
                                         image_size=(num_channels, dims, dims),
                                         num_classes=10,
                                         transform=transform_train,
                                         target_transform=None,
                                         random_offset=0)
    train_loader_source = torch.utils.data.DataLoader(
        train_set_source, batch_size=args.num_base_examples, shuffle=True)
else:
    # get the fixed images from the same dataset as the training data
    train_set_source = train_set
    train_loader_source = torch.utils.data.DataLoader(
        train_set_source, batch_size=args.num_base_examples, shuffle=True)

if torch.cuda.is_available():
    # checks whether a cuda gpu is available
    device = torch.cuda.current_device()
    print("use GPU", device)
    print("GPU ID {}".format(torch.cuda.current_device()))
else:
    print("use CPU")
    device = torch.device('cpu')  # sets the device to be CPU

train_loader_source_iter = iter(train_loader_source)

if args.balanced_source:
    # use a balanced set of fixed examples - same number of examples per class
    class_counts = {}
    fixed_input = []
    fixed_target = []
    for batch_fixed_i, batch_fixed_t in train_loader_source_iter:
        # stop once enough base examples have been collected overall
        if sum(class_counts.values()) >= args.num_base_examples:
            break
        for fixed_i, fixed_t in zip(batch_fixed_i, batch_fixed_t):
            if len(class_counts.keys()) < num_classes:
                # not all classes seen yet: accept new classes as they appear
                if int(fixed_t) in class_counts:
                    if class_counts[int(fixed_t)] < args.num_base_examples // num_classes:
                        class_counts[int(fixed_t)] += 1
                        fixed_input.append(fixed_i)
                        fixed_target.append(int(fixed_t))
                else:
                    class_counts[int(int(fixed_t))] = 1
                    fixed_input.append(fixed_i)
                    fixed_target.append(int(fixed_t))
            else:
                # all classes seen: only top up classes below the per-class quota
                if int(fixed_t) in class_counts:
                    if class_counts[int(fixed_t)] < args.num_base_examples // num_classes:
                        class_counts[int(fixed_t)] += 1
                        fixed_input.append(fixed_i)
                        fixed_target.append(int(fixed_t))
    fixed_input = torch.stack(fixed_input).to(device=device)
    fixed_target = torch.Tensor(fixed_target).to(device=device)
else:
    # used for cross-dataset scenario - random selection of classes
    # not taking into account the original classes
    fixed_input, fixed_target = next(train_loader_source_iter)
    fixed_input = fixed_input.to(device=device)
    fixed_target = fixed_target.to(device=device)

# define loss function (criterion)
criterion = nn.CrossEntropyLoss().to(device=device)

# start at uniform labels and then learn them
labels = torch.zeros((args.num_base_examples, num_classes),
                     requires_grad=True,
                     device=device)
# NOTE: the zeros tensor above only fixes dtype/device; it is immediately
# replaced by this uniform-distribution leaf tensor, which is what gets
# optimized by labels_opt below.
labels = labels.new_tensor(
    [[float(1.0 / num_classes) for e in range(num_classes)]
     for i in range(args.num_base_examples)],
    requires_grad=True,
    device=device)

# define an optimizer for labels
labels_opt = torch.optim.Adam([labels])

# enable using meta-architectures for second-order meta-learning
# allows assigning fast weights
cudnn.benchmark = True
M.LeNetMeta.meta = True
M.AlexCifarNetMeta.meta = True
M.BasicBlockMeta.meta = True
M.BottleneckMeta.meta = True
M.ResNetMeta.meta = True

# define the models to use
if args.target == "cifar10" or args.target == "cifar100":
    if args.resnet:
        model = M.ResNetMeta(dataset=args.target,
                             depth=18,
                             num_classes=num_classes,
                             bottleneck=False,
                             device=device).to(device=device)
        model_name = 'resnet'
    else:
        model = M.AlexCifarNetMeta(args).to(device=device)
        model_name = 'alexnet'
else:
    model = M.LeNetMeta(args).to(device=device)
    model_name = 'LeNet'

optimizer = torch.optim.Adam(model.parameters())

if args.baseline:
    # Baseline mode: keep the true (optionally smoothed) labels fixed and
    # only measure how well models trained on the base examples perform.
    create_json_experiment_log(fixed_target)
    # remap the targets - only relevant in cross-dataset
    fixed_target = remap_targets(fixed_target, num_classes)
    # printing the labels helps ensure the seeds work
    print('The labels of the fixed examples are')
    print(fixed_target.tolist())
    labels = one_hot(fixed_target.long(), num_classes)
    # add smoothing to the baseline if selected
    if args.label_smoothing > 0:
        labels = create_smooth_labels(labels, args.label_smoothing,
                                      num_classes)
    # use the validation set to find a suitable number of iterations for training
    num_baseline_steps, errors_list, num_steps_used = find_best_num_steps(
        val_loader, criterion, fixed_input, labels)
    print('Number of steps to use for the baseline: ' +
          str(num_baseline_steps))
    experiment_update_dict = {
        'num_baseline_steps': num_baseline_steps,
        'errors_list': errors_list,
        'num_steps_used': num_steps_used
    }
    update_json_experiment_log_dict(experiment_update_dict)
    if args.test_various_models:
        assert args.target == "cifar10", "test various models is only meant to be used for CIFAR-10"
        model_name_list = ['alexnet', 'LeNet', 'resnet']
        for model_name_test in model_name_list:
            # do 20 repetitions of training from scratch
            for test_i in range(20):
                print('Test repetition ' + str(test_i))
                test_err1, test_loss = test(test_loader, model_name_test,
                                            criterion, fixed_input, labels,
                                            num_baseline_steps)
                print('Test error (top-1 error):', test_err1)
                experiment_update_dict = {
                    'test_top_1_error_' + model_name_test: test_err1,
                    'test_loss_' + model_name_test: test_loss,
                    'num_test_steps_' + model_name_test: num_baseline_steps
                }
                update_json_experiment_log_dict(experiment_update_dict)
    else:
        # do 20 repetitions of training from scratch
        for test_i in range(20):
            print('Test repetition ' + str(test_i))
            test_err1, test_loss = test(test_loader, model_name, criterion,
                                        fixed_input, labels,
                                        num_baseline_steps)
            print('Test error (top-1 error):', test_err1)
            experiment_update_dict = {
                'test_top_1_error': test_err1,
                'test_loss': test_loss,
                'num_test_steps': num_baseline_steps
            }
            update_json_experiment_log_dict(experiment_update_dict)
else:
    # Full run: learn the synthetic labels end to end, validating every
    # 5 epochs and testing the best labels afterwards.
    create_json_experiment_log(fixed_target)
    # start measuring time
    start_time = time.time()
    # initialize variables to decide when to restart a model
    ma_list = []
    ma_sum = 0
    lowest_ma_sum = 999999999
    current_num_steps = 0
    num_steps_list = []
    num_steps_from_min = 0
    val_err1 = 100.0
    val_loss = 5.0
    num_steps_val = 0
    with tqdm.tqdm(total=args.epochs) as pbar_epochs:
        for epoch in range(0, args.epochs):
            # NOTE(review): train() as defined in this file ends its parameter
            # list with `normalize`, but this call passes one fewer argument -
            # confirm the call site matches the signature.
            train_err1, train_loss, labels, model_loss, ma_list, ma_sum, lowest_ma_sum, current_num_steps, num_steps_list, num_steps_from_min, model, optimizer = \
                train(train_loader, model, fixed_input, labels, criterion,
                      labels_opt, epoch, optimizer, ma_list, ma_sum,
                      lowest_ma_sum, current_num_steps, num_steps_list,
                      num_steps_from_min)
            # evaluate on the validation set only every 5 epochs as it can be quite expensive to train a new model from scratch
            if epoch % 5 == 4:
                # calculate the number of steps to use
                if len(num_steps_list) == 0:
                    num_steps_val = current_num_steps
                else:
                    # average over the last (up to) 3 recorded step counts
                    num_steps_val = int(np.mean(num_steps_list[-3:]))
                # NOTE(review): validate() as defined in this file also takes a
                # trailing `normalize` parameter - confirm this call matches.
                val_err1, val_loss = validate(val_loader, model, criterion,
                                              epoch, fixed_input, labels,
                                              num_steps_val)
                # keep the labels/step count of the best validation error so far
                # NOTE(review): best_err1 is read here but not initialized in
                # this chunk - presumably set earlier in the file; verify.
                if val_err1 <= best_err1:
                    best_labels = labels.detach().clone()
                    best_num_steps = num_steps_val
                best_err1 = min(val_err1, best_err1)
                print('Current best val error (top-1 error):', best_err1)
            pbar_epochs.update(1)
            experiment_update_dict = {
                'train_top_1_error': train_err1,
                'train_loss': train_loss,
                'val_top_1_error': val_err1,
                'val_loss': val_loss,
                'model_loss': model_loss,
                'epoch': epoch,
                'num_val_steps': num_steps_val
            }
            # save the best labels so that we can analyse them
            if epoch == args.epochs - 1:
                experiment_update_dict['labels'] = best_labels.tolist()
            update_json_experiment_log_dict(experiment_update_dict)
    print('Best val error (top-1 error):', best_err1)
    # stop measuring time
    experiment_update_dict = {'total_train_time': time.time() - start_time}
    update_json_experiment_log_dict(experiment_update_dict)
    # this does number of steps analysis - what happens if we do more or fewer steps for test training
    if args.num_steps_analysis:
        num_steps_add = [-50, -20, -10, 0, 10, 20, 50, 100]
        for num_steps_add_item in num_steps_add:
            # start measuring time for testing
            start_time = time.time()
            local_errs = []
            local_losses = []
            local_num_steps = best_num_steps + num_steps_add_item
            print('Number of steps for training: ' + str(local_num_steps))
            # each number of steps will have a robust estimate by using 20 repetitions
            for test_i in range(20):
                print('Test repetition ' + str(test_i))
                test_err1, test_loss = test(test_loader, model_name,
                                            criterion, fixed_input,
                                            best_labels, local_num_steps)
                local_errs.append(test_err1)
                local_losses.append(test_loss)
                print('Test error (top-1 error):', test_err1)
            experiment_update_dict = {
                'test_top_1_error': local_errs,
                'test_loss': local_losses,
                'total_test_time': time.time() - start_time,
                'num_test_steps': local_num_steps
            }
            update_json_experiment_log_dict(experiment_update_dict)
    else:
        if args.test_various_models:
            assert args.target == "cifar10", "test various models is only meant to be used for CIFAR-10"
            model_name_list = ['alexnet', 'LeNet', 'resnet']
            for model_name_test in model_name_list:
                for test_i in range(20):
                    print(model_name_test)
                    print('Test repetition ' + str(test_i))
                    test_err1, test_loss = test(test_loader, model_name_test,
                                                criterion, fixed_input,
                                                best_labels, best_num_steps)
                    print('Test error (top-1 error):', test_err1)
                    experiment_update_dict = {
                        'test_top_1_error_' + model_name_test: test_err1,
                        'test_loss_' + model_name_test: test_loss,
                        'total_test_time_' + model_name_test:
                        time.time() - start_time,
                        'num_test_steps_' + model_name_test: best_num_steps
                    }
                    update_json_experiment_log_dict(experiment_update_dict)
        else:
            for test_i in range(20):
                print('Test repetition ' + str(test_i))
                test_err1, test_loss = test(test_loader, model_name,
                                            criterion, fixed_input,
                                            best_labels, best_num_steps)
                print('Test error (top-1 error):', test_err1)
                experiment_update_dict = {
                    'test_top_1_error': test_err1,
                    'test_loss': test_loss,
                    'total_test_time': time.time() - start_time,
                    'num_test_steps': best_num_steps
                }
                update_json_experiment_log_dict(experiment_update_dict)