def ce_gradient_pair_scatter(
    model, data_loader, d1=0, d2=1, max_num_examples=2000, plt=None
):
    if plt is None:
        plt = matplotlib.pyplot
    model.eval()
    pred = utils.apply_on_dataset(
        model=model,
        dataset=data_loader.dataset,
        output_keys_regexp="pred",
        max_num_examples=max_num_examples,
        description="grad-pair-scatter:pred",
    )["pred"]
    n_examples = min(len(data_loader.dataset), max_num_examples)
    labels = []
    for idx in range(n_examples):
        labels.append(data_loader.dataset[idx][1])
    labels = torch.tensor(labels, dtype=torch.long)
    labels = F.one_hot(labels, num_classes=model.num_classes).float()
    labels = utils.to_cpu(labels)

    # gradient of the cross-entropy loss w.r.t. the logits: softmax(pred) - one_hot(label)
    grad_wrt_logits = torch.softmax(pred, dim=-1) - labels
    grad_wrt_logits = utils.to_numpy(grad_wrt_logits)

    fig, ax = plt.subplots(1, figsize=(5, 5))
    plt.scatter(grad_wrt_logits[:, d1], grad_wrt_logits[:, d2])
    ax.set_xlabel(str(d1))
    ax.set_ylabel(str(d2))
    # L = np.percentile(grad_wrt_logits, q=5, axis=0)
    # R = np.percentile(grad_wrt_logits, q=95, axis=0)
    # ax.set_xlim(L[d1], R[d1])
    # ax.set_ylim(L[d2], R[d2])
    ax.set_title("Two coordinates of the gradient w.r.t. the logits")
    return fig, plt
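# ce_gradient_pair_scatter above and ce_gradient_norm_histogram below rely on the
# closed-form identity that, for cross-entropy loss, the gradient with respect to
# the logits equals softmax(logits) - one_hot(label). The following self-contained
# check (not part of the original code; batch_size and num_classes are illustrative)
# verifies the identity against autograd.
def _check_ce_gradient_identity(num_classes=10, batch_size=4):
    import torch
    import torch.nn.functional as F

    logits = torch.randn(batch_size, num_classes, requires_grad=True)
    labels = torch.randint(0, num_classes, (batch_size,))

    # autograd gradient of the summed cross-entropy w.r.t. the logits
    loss = F.cross_entropy(logits, labels, reduction="sum")
    loss.backward()

    # closed-form gradient: softmax(logits) - one_hot(labels)
    closed_form = torch.softmax(logits.detach(), dim=-1) \
        - F.one_hot(labels, num_classes=num_classes).float()
    assert torch.allclose(logits.grad, closed_form, atol=1e-6)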
def ce_gradient_norm_histogram(
    model, data_loader, tensorboard, epoch, name, max_num_examples=5000
):
    model.eval()
    pred = utils.apply_on_dataset(
        model=model,
        dataset=data_loader.dataset,
        output_keys_regexp="pred",
        description="grad-histogram:pred",
        max_num_examples=max_num_examples,
    )["pred"]
    n_examples = min(len(data_loader.dataset), max_num_examples)
    labels = []
    for idx in range(n_examples):
        labels.append(data_loader.dataset[idx][1])
    labels = torch.tensor(labels, dtype=torch.long)
    labels = F.one_hot(labels, num_classes=model.num_classes).float()
    labels = utils.to_cpu(labels)

    grad_wrt_logits = torch.softmax(pred, dim=-1) - labels
    grad_norms = torch.sum(grad_wrt_logits ** 2, dim=-1)
    grad_norms = utils.to_numpy(grad_norms)

    try:
        tensorboard.add_histogram(tag=name, values=grad_norms, global_step=epoch)
    except ValueError as e:
        print("Tensorboard histogram error: {}".format(e))
def plot_predictions(model, data_loader, key, plt=None):
    if plt is None:
        plt = matplotlib.pyplot
    model.eval()
    n_examples = 10
    pred = utils.apply_on_dataset(
        model=model,
        dataset=data_loader.dataset,
        output_keys_regexp=key,
        max_num_examples=n_examples,
        description='plot_predictions:{}'.format(key))[key]
    probs = torch.softmax(pred, dim=1)
    probs = utils.to_numpy(probs)

    data = [data_loader.dataset[i][0] for i in range(n_examples)]
    labels = [data_loader.dataset[i][1] for i in range(n_examples)]
    samples = torch.stack(data, dim=0)
    samples = revert_normalization(samples, data_loader.dataset)
    samples = utils.to_numpy(samples)

    fig, ax = plt.subplots(nrows=n_examples, ncols=2,
                           figsize=(2 * 2, 2 * n_examples))
    for i in range(n_examples):
        ax[i][0].imshow(get_image(samples[i]), vmin=0, vmax=1)
        ax[i][0].set_axis_off()
        ax[i][0].set_title('labeled as {}'.format(labels[i]))
        ax[i][1].bar(range(model.num_classes), probs[i])
        ax[i][1].set_xticks(range(model.num_classes))
    return fig, plt
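# A small usage sketch (not from the original code): render the prediction grid
# produced by plot_predictions for a trained model and write it to disk. The
# out_path default is illustrative; model and data_loader are assumed to come
# from the training pipeline in this repository.
def save_prediction_grid(model, data_loader, out_path='predictions.png'):
    fig, _ = plot_predictions(model, data_loader, key='pred')
    fig.savefig(out_path, bbox_inches='tight')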
def pred_gradient_pair_scatter(model, data_loader, d1=0, d2=1,
                               max_num_examples=2000, plt=None):
    if plt is None:
        plt = matplotlib.pyplot
    model.eval()
    grad_pred = utils.apply_on_dataset(
        model=model,
        dataset=data_loader.dataset,
        output_keys_regexp='grad_pred',
        max_num_examples=max_num_examples,
        description='grad-pair-scatter:grad_pred')['grad_pred']
    grad_pred = utils.to_numpy(grad_pred)

    fig, ax = plt.subplots(1, figsize=(5, 5))
    plt.scatter(grad_pred[:, d1], grad_pred[:, d2])
    ax.set_xlabel(str(d1))
    ax.set_ylabel(str(d2))
    # L = np.percentile(grad_pred, q=5, axis=0)
    # R = np.percentile(grad_pred, q=95, axis=0)
    # ax.set_xlim(L[d1], R[d1])
    # ax.set_ylim(L[d2], R[d2])
    ax.set_title('Two coordinates of the predicted gradient w.r.t. the logits')
    return fig, plt
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', '-d', default='cuda')
    parser.add_argument('--batch_size', '-b', type=int, default=256)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--dataset', '-D', type=str, default='mnist',
                        choices=['mnist', 'cifar10', 'cifar100', 'clothing1m', 'imagenet'])
    parser.add_argument('--data_augmentation', '-A', action='store_true',
                        dest='data_augmentation')
    parser.set_defaults(data_augmentation=False)
    parser.add_argument('--num_train_examples', type=int, default=None)
    parser.add_argument('--label_noise_level', '-n', type=float, default=0.0)
    parser.add_argument('--label_noise_type', type=str, default='flip',
                        choices=['flip', 'error', 'cifar10_custom'])
    parser.add_argument('--transform_function', type=str, default=None,
                        choices=[None, 'remove_random_chunks'])
    parser.add_argument('--clean_validation', dest='clean_validation', action='store_true')
    parser.set_defaults(clean_validation=False)
    parser.add_argument('--remove_prob', type=float, default=0.5)
    parser.add_argument('--load_from', type=str, default=None, required=True)
    parser.add_argument('--output_dir', '-o', type=str, default=None)
    args = parser.parse_args()
    print(args)

    # Load data
    _, _, test_loader = datasets.load_data_from_arguments(args)

    print(f"Testing the model saved at {args.load_from}")
    model = utils.load(args.load_from, device=args.device)

    ret = utils.apply_on_dataset(model, test_loader.dataset,
                                 batch_size=args.batch_size,
                                 output_keys_regexp='pred|label',
                                 description='Testing')
    pred = ret['pred']
    labels = ret['label']

    if args.output_dir is not None:
        with open(os.path.join(args.output_dir, 'test_predictions.pkl'), 'wb') as f:
            pickle.dump({'pred': pred, 'labels': labels}, f)

    accuracy = torch.mean((pred.argmax(dim=1) == labels).float())
    print(accuracy)

    if args.output_dir is not None:
        with open(os.path.join(args.output_dir, 'test_accuracy.txt'), 'w') as f:
            f.write("{}\n".format(accuracy))
def pred_gradient_norm_histogram(
    model, data_loader, tensorboard, epoch, name, max_num_examples=5000
):
    model.eval()
    grad_pred = utils.apply_on_dataset(
        model=model,
        dataset=data_loader.dataset,
        output_keys_regexp="grad_pred",
        description="grad-histogram:grad_pred",
        max_num_examples=max_num_examples,
    )["grad_pred"]
    grad_norms = torch.sum(grad_pred ** 2, dim=-1)
    grad_norms = utils.to_numpy(grad_norms)

    try:
        tensorboard.add_histogram(tag=name, values=grad_norms, global_step=epoch)
    except ValueError as e:
        print("Tensorboard histogram error: {}".format(e))
def estimate_transition(load_from, data_loader, device="cpu", batch_size=256):
    """ Estimates the label noise transition matrix.
    The code is adapted from the original implementation.
    Source: https://github.com/giorgiop/loss-correction/.
    """
    assert load_from is not None
    model = utils.load(load_from, device=device)
    pred = utils.apply_on_dataset(
        model=model,
        dataset=data_loader.dataset,
        batch_size=batch_size,
        cpu=True,
        description="Estimating transition matrix",
        output_keys_regexp="pred",
    )["pred"]
    pred = torch.softmax(pred, dim=1)
    pred = utils.to_numpy(pred)

    c = model.num_classes
    T = np.zeros((c, c))
    filter_outlier = True

    # find a 'perfect example' for each class
    for i in range(c):
        if not filter_outlier:
            idx_best = np.argmax(pred[:, i])
        else:
            # ignore the most confident 3% of predictions to be robust to outliers
            thresh = np.percentile(pred[:, i], 97, interpolation="higher")
            robust_eta = pred[:, i]
            robust_eta[robust_eta >= thresh] = 0.0
            idx_best = np.argmax(robust_eta)
        for j in range(c):
            T[i, j] = pred[idx_best, j]

    # row-normalize so that each row is a probability distribution
    row_sums = T.sum(axis=1, keepdims=True)
    T /= row_sums

    T = torch.tensor(T, dtype=torch.float).to(device)
    print(T)
    return T
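# A minimal sketch (an assumption, not the repository's implementation) of how the
# estimated transition matrix T could be consumed for forward loss correction in
# the style of the loss-correction repository cited above: the clean class
# posterior softmax(logits) is pushed through T (rows index the true class,
# columns the noisy class), and the negative log-likelihood is taken at the
# observed noisy label. All tensors are assumed to live on the same device and
# noisy_labels is assumed to be a LongTensor of shape (B,).
def forward_corrected_loss(logits, noisy_labels, T, eps=1e-12):
    import torch

    probs = torch.softmax(logits, dim=1)   # P(true class | x), shape (B, c)
    noisy_probs = probs @ T                 # P(noisy class | x) = P(true class | x) @ T
    picked = noisy_probs[torch.arange(len(noisy_labels)), noisy_labels]
    return -torch.log(picked + eps).mean()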
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, required=True)
    parser.add_argument('--device', '-d', default='cuda')
    parser.add_argument('--batch_size', '-b', type=int, default=256)
    parser.add_argument('--epochs', '-e', type=int, default=400)
    parser.add_argument('--stopping_param', type=int, default=50)
    parser.add_argument('--save_iter', '-s', type=int, default=10)
    parser.add_argument('--vis_iter', '-v', type=int, default=10)
    parser.add_argument('--log_dir', '-l', type=str, default=None)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--dataset', '-D', type=str, default='mnist',
                        choices=['mnist', 'cifar10', 'cifar100', 'clothing1m', 'imagenet'])
    parser.add_argument('--data_augmentation', '-A', action='store_true',
                        dest='data_augmentation')
    parser.set_defaults(data_augmentation=False)
    parser.add_argument('--num_train_examples', type=int, default=None)
    parser.add_argument('--label_noise_level', '-n', type=float, default=0.0)
    parser.add_argument('--label_noise_type', type=str, default='error',
                        choices=['error', 'cifar10_custom'])
    parser.add_argument('--transform_function', type=str, default=None,
                        choices=[None, 'remove_random_chunks'])
    parser.add_argument('--clean_validation', dest='clean_validation', action='store_true')
    parser.set_defaults(clean_validation=False)
    parser.add_argument('--remove_prob', type=float, default=0.5)
    parser.add_argument('--model_class', '-m', type=str, default='StandardClassifier')
    parser.add_argument('--loss_function', type=str, default='ce',
                        choices=['ce', 'mse', 'mae', 'gce', 'dmi', 'fw', 'none'])
    parser.add_argument('--loss_function_param', type=float, default=1.0)
    parser.add_argument('--load_from', type=str, default=None)
    parser.add_argument('--grad_weight_decay', '-L', type=float, default=0.0)
    parser.add_argument('--grad_l1_penalty', '-S', type=float, default=0.0)
    parser.add_argument('--lamb', type=float, default=1.0)
    parser.add_argument('--pretrained_arg', '-r', type=str, default=None)
    parser.add_argument('--sample_from_q', action='store_true', dest='sample_from_q')
    parser.set_defaults(sample_from_q=False)
    parser.add_argument('--q_dist', type=str, default='Gaussian',
                        choices=['Gaussian', 'Laplace', 'dot'])
    parser.add_argument('--no-detach', dest='detach', action='store_false')
    parser.set_defaults(detach=True)
    parser.add_argument('--warm_up', type=int, default=0,
                        help='Number of epochs to skip before '
                             'starting to train using predicted gradients')
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--add_noise', action='store_true', dest='add_noise',
                        help='add noise to the gradients of a standard classifier.')
    parser.set_defaults(add_noise=False)
    parser.add_argument('--noise_type', type=str, default='Gaussian',
                        choices=['Gaussian', 'Laplace'])
    parser.add_argument('--noise_std', type=float, default=0.0)
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    args = parser.parse_args()
    print(args)

    # Load data
    train_loader, val_loader, test_loader = datasets.load_data_from_arguments(args)

    # Options
    optimization_args = {
        'optimizer': {
            'name': 'adam',
            'lr': args.lr,
            'weight_decay': args.weight_decay
        }
    }
    # optimization_args = {
    #     'optimizer': {
    #         'name': 'sgd',
    #         'lr': 1e-3,
    #     },
    #     'scheduler': {
    #         'step_size': 15,
    #         'gamma': 1.25
    #     }
    # }

    with open(args.config, 'r') as f:
        architecture_args = json.load(f)

    model_class = getattr(methods, args.model_class)
    model = model_class(input_shape=train_loader.dataset[0][0].shape,
                        architecture_args=architecture_args,
                        pretrained_arg=args.pretrained_arg,
                        device=args.device,
                        grad_weight_decay=args.grad_weight_decay,
                        grad_l1_penalty=args.grad_l1_penalty,
                        lamb=args.lamb,
                        sample_from_q=args.sample_from_q,
                        q_dist=args.q_dist,
                        load_from=args.load_from,
                        loss_function=args.loss_function,
                        loss_function_param=args.loss_function_param,
                        add_noise=args.add_noise,
                        noise_type=args.noise_type,
                        noise_std=args.noise_std,
                        detach=args.detach,
                        warm_up=args.warm_up)

    metrics_list = []
    if args.dataset == 'imagenet':
        metrics_list.append(metrics.TopKAccuracy(k=5, output_key='pred'))

    training.train(model=model,
                   train_loader=train_loader,
                   val_loader=val_loader,
                   epochs=args.epochs,
                   save_iter=args.save_iter,
                   vis_iter=args.vis_iter,
                   optimization_args=optimization_args,
                   log_dir=args.log_dir,
                   args_to_log=args,
                   stopping_param=args.stopping_param,
                   metrics=metrics_list)

    # if training finishes successfully, compute the test score
    print("Testing the best validation model...")
    model = utils.load(os.path.join(args.log_dir, 'checkpoints', 'best_val.mdl'),
                       device=args.device)
    pred = utils.apply_on_dataset(model, test_loader.dataset,
                                  batch_size=args.batch_size,
                                  output_keys_regexp='pred',
                                  description='Testing')['pred']
    labels = [p[1] for p in test_loader.dataset]
    labels = torch.tensor(labels, dtype=torch.long)
    labels = utils.to_cpu(labels)

    with open(os.path.join(args.log_dir, 'test_predictions.pkl'), 'wb') as f:
        pickle.dump({'pred': pred, 'labels': labels}, f)

    accuracy = torch.mean((pred.argmax(dim=1) == labels).float())
    with open(os.path.join(args.log_dir, 'test_accuracy.txt'), 'w') as f:
        f.write("{}\n".format(accuracy))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", "-d", default="cuda")
    parser.add_argument("--batch_size", "-b", type=int, default=256)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--dataset", "-D", type=str, default="mnist",
        choices=["mnist", "cifar10", "cifar100", "clothing1m", "imagenet"],
    )
    parser.add_argument(
        "--data_augmentation", "-A", action="store_true", dest="data_augmentation"
    )
    parser.set_defaults(data_augmentation=False)
    parser.add_argument("--num_train_examples", type=int, default=None)
    parser.add_argument("--label_noise_level", "-n", type=float, default=0.0)
    parser.add_argument(
        "--label_noise_type", type=str, default="flip",
        choices=["flip", "error", "cifar10_custom"],
    )
    parser.add_argument(
        "--transform_function", type=str, default=None,
        choices=[None, "remove_random_chunks"],
    )
    parser.add_argument(
        "--clean_validation", dest="clean_validation", action="store_true"
    )
    parser.set_defaults(clean_validation=False)
    parser.add_argument("--remove_prob", type=float, default=0.5)
    parser.add_argument("--load_from", type=str, default=None, required=True)
    parser.add_argument("--output_dir", "-o", type=str, default=None)
    args = parser.parse_args()
    print(args)

    # Load data
    _, _, test_loader = datasets.load_data_from_arguments(args)

    print(f"Testing the model saved at {args.load_from}")
    model = utils.load(args.load_from, device=args.device)

    ret = utils.apply_on_dataset(
        model,
        test_loader.dataset,
        batch_size=args.batch_size,
        output_keys_regexp="pred|label",
        description="Testing",
    )
    pred = ret["pred"]
    labels = ret["label"]

    if args.output_dir is not None:
        with open(os.path.join(args.output_dir, "test_predictions.pkl"), "wb") as f:
            pickle.dump({"pred": pred, "labels": labels}, f)

    accuracy = torch.mean((pred.argmax(dim=1) == labels).float())
    print(accuracy)

    if args.output_dir is not None:
        with open(os.path.join(args.output_dir, "test_accuracy.txt"), "w") as f:
            f.write("{}\n".format(accuracy))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", "-c", type=str, required=True)
    parser.add_argument("--device", "-d", default="cuda")
    parser.add_argument("--batch_size", "-b", type=int, default=128)
    parser.add_argument("--epochs", "-e", type=int, default=4000)
    parser.add_argument("--stopping_param", type=int, default=2**30)
    parser.add_argument("--save_iter", "-s", type=int, default=100)
    parser.add_argument("--vis_iter", "-v", type=int, default=10)
    parser.add_argument("--log_dir", "-l", type=str, default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--dataset", "-D", type=str, default="cifar10",
        choices=["mnist", "cifar10", "cifar100", "clothing1m", "imagenet"],
    )
    parser.add_argument("--data_augmentation", "-A", action="store_true",
                        dest="data_augmentation")
    parser.set_defaults(data_augmentation=False)
    parser.add_argument("--num_train_examples", type=int, default=None)
    parser.add_argument("--label_noise_level", "-n", type=float, default=0.0)
    parser.add_argument(
        "--label_noise_type", type=str, default="error",
        choices=["error", "cifar10_custom"],
    )
    parser.add_argument(
        "--transform_function", type=str, default=None,
        choices=[None, "remove_random_chunks"],
    )
    parser.add_argument("--clean_validation", dest="clean_validation", action="store_true")
    parser.set_defaults(clean_validation=False)
    parser.add_argument("--remove_prob", type=float, default=0.5)
    parser.add_argument("--model_class", "-m", type=str, default="StandardClassifier")
    parser.add_argument("--load_from", type=str, default=None)
    parser.add_argument("--grad_weight_decay", "-L", type=float, default=0.0)
    parser.add_argument("--lamb", type=float, default=1.0)
    parser.add_argument("--pretrained_arg", "-r", type=str, default=None)
    parser.add_argument("--sample_from_q", action="store_true", dest="sample_from_q")
    parser.set_defaults(sample_from_q=False)
    parser.add_argument("--q_dist", type=str, default="Gaussian",
                        choices=["Gaussian", "Laplace", "dot"])
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument(
        "--k", "-k", type=int, required=False, default=10,
        help="width parameter of ResNet18-k",
    )
    parser.add_argument("--exclude_percent", type=float, default=0.0)
    args = parser.parse_args()
    print(args)

    # Load data
    train_loader, val_loader, test_loader = datasets.load_data_from_arguments(args)

    # Options
    optimization_args = {
        "optimizer": {
            "name": "adam",
            "lr": args.lr,
            "weight_decay": args.weight_decay
        }
    }

    with open(args.config, "r") as f:
        architecture_args = json.load(f)

    # set the width parameter k of the double-descent networks
    if ("classifier" in architecture_args
            and architecture_args["classifier"].get("net", "").find("double-descent") != -1):
        architecture_args["classifier"]["k"] = args.k
    if ("q-network" in architecture_args
            and architecture_args["q-network"].get("net", "").find("double-descent") != -1):
        architecture_args["q-network"]["k"] = args.k

    model_class = getattr(methods, args.model_class)
    model = model_class(
        input_shape=train_loader.dataset[0][0].shape,
        architecture_args=architecture_args,
        pretrained_arg=args.pretrained_arg,
        device=args.device,
        grad_weight_decay=args.grad_weight_decay,
        lamb=args.lamb,
        sample_from_q=args.sample_from_q,
        q_dist=args.q_dist,
        load_from=args.load_from,
        loss_function="ce",
    )

    training.train(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=args.epochs,
        save_iter=args.save_iter,
        vis_iter=args.vis_iter,
        optimization_args=optimization_args,
        log_dir=args.log_dir,
        args_to_log=args,
        stopping_param=args.stopping_param,
    )

    # test the best and the final model
    models_to_test = [
        {"name": "best", "file": "best_val.mdl"},
        {"name": "final", "file": "final.mdl"},
    ]
    for spec in models_to_test:
        print("Testing the {} model...".format(spec["name"]))
        model = utils.load(os.path.join(args.log_dir, "checkpoints", spec["file"]),
                           device=args.device)
        pred = utils.apply_on_dataset(
            model,
            test_loader.dataset,
            batch_size=args.batch_size,
            output_keys_regexp="pred",
            description="Testing",
        )["pred"]
        labels = [p[1] for p in test_loader.dataset]
        labels = torch.tensor(labels, dtype=torch.long)
        labels = utils.to_cpu(labels)

        with open(os.path.join(args.log_dir,
                               "{}_test_predictions.pkl".format(spec["name"])), "wb") as f:
            pickle.dump({"pred": pred, "labels": labels}, f)

        accuracy = torch.mean((pred.argmax(dim=1) == labels).float())
        with open(os.path.join(args.log_dir,
                               "{}_test_accuracy.txt".format(spec["name"])), "w") as f:
            f.write("{}\n".format(accuracy))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", "-c", type=str, default=None)
    parser.add_argument("--device", "-d", default="cuda")
    parser.add_argument("--batch_size", "-b", type=int, default=256)
    parser.add_argument("--epochs", "-e", type=int, default=400)
    parser.add_argument("--stopping_param", type=int, default=50)
    parser.add_argument("--save_iter", "-s", type=int, default=10)
    parser.add_argument("--vis_iter", "-v", type=int, default=10)
    parser.add_argument("--log_dir", "-l", type=str, default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--dataset", "-D", type=str, default="mnist",
        choices=["mnist", "cifar10", "cifar100", "clothing1m", "imagenet", "cover"],
    )
    # parser.add_argument("--image-root", metavar="DIR", help="path to images")
    # parser.add_argument("--label", metavar="DIR", help="path to label file")
    parser.add_argument("--data_augmentation", "-A", action="store_true",
                        dest="data_augmentation")
    parser.set_defaults(data_augmentation=False)
    parser.add_argument("--num_train_examples", type=int, default=None)
    parser.add_argument("--label_noise_level", "-n", type=float, default=0.0)
    parser.add_argument(
        "--label_noise_type", type=str, default="error",
        choices=["error", "cifar10_custom"],
    )
    parser.add_argument(
        "--transform_function", type=str, default=None,
        choices=[None, "remove_random_chunks"],
    )
    parser.add_argument("--clean_validation", dest="clean_validation", action="store_true")
    parser.set_defaults(clean_validation=False)
    parser.add_argument("--remove_prob", type=float, default=0.5)
    parser.add_argument("--model_class", "-m", type=str, default="StandardClassifier")
    parser.add_argument(
        "--loss_function", type=str, default="ce",
        choices=["ce", "mse", "mae", "gce", "dmi", "fw", "none"],
    )
    parser.add_argument("--loss_function_param", type=float, default=1.0)
    parser.add_argument("--load_from", type=str, default=None)
    parser.add_argument("--grad_weight_decay", "-L", type=float, default=0.0)
    parser.add_argument("--grad_l1_penalty", "-S", type=float, default=0.0)
    parser.add_argument("--lamb", type=float, default=1.0)
    parser.add_argument("--pretrained_arg", "-r", type=str, default=None)
    parser.add_argument("--sample_from_q", action="store_true", dest="sample_from_q")
    parser.set_defaults(sample_from_q=False)
    parser.add_argument("--q_dist", type=str, default="Gaussian",
                        choices=["Gaussian", "Laplace", "dot"])
    parser.add_argument("--no-detach", dest="detach", action="store_false")
    parser.set_defaults(detach=True)
    parser.add_argument(
        "--warm_up", type=int, default=0,
        help="Number of epochs to skip before "
             "starting to train using predicted gradients",
    )
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument(
        "--add_noise", action="store_true", dest="add_noise",
        help="add noise to the gradients of a standard classifier.",
    )
    parser.set_defaults(add_noise=False)
    parser.add_argument("--noise_type", type=str, default="Gaussian",
                        choices=["Gaussian", "Laplace"])
    parser.add_argument("--noise_std", type=float, default=0.0)
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    args = parser.parse_args()
    print(args)

    # Load data
    args.image_root = [
        "/data/chenlong.1024/5198852-tiktok-1w_images",
        "/data/chenlong.1024/5205046-tiktok-10w_images",
        "/data/chenlong.1024/5599074-tiktok-impr_cnt20_images",
        "/data/chenlong.1024/5600297-tiktok-impr_cnt10_images",
    ]
    args.label = [
        "/data/chenlong.1024/5198852-tiktok-1w.csv",
        "/data/chenlong.1024/5205046-tiktok-10w.csv",
        "/data/chenlong.1024/5599074-tiktok-impr_cnt20.csv",
        "/data/chenlong.1024/5600297-tiktok-impr_cnt10.csv",
    ]
    train_loader, val_loader, test_loader = datasets.load_data_from_arguments(args)

    # Options
    optimization_args = {
        "optimizer": {
            "name": "adam",
            "lr": args.lr,
            "weight_decay": args.weight_decay
        },
        "scheduler": {
            "step_size": 20,
            "gamma": 0.3
        },
    }
    # optimization_args = {
    #     'optimizer': {
    #         'name': 'sgd',
    #         'lr': 1e-3,
    #     },
    #     'scheduler': {
    #         'step_size': 15,
    #         'gamma': 1.25
    #     }
    # }

    model_class = getattr(methods, args.model_class)
    if "CoverModel" in args.model_class:
        model = model_class(
            num_classes=2,
            pretrained=True,
            device=args.device,
            grad_weight_decay=args.grad_weight_decay,
            grad_l1_penalty=args.grad_l1_penalty,
            lamb=args.lamb,
            sample_from_q=args.sample_from_q,
            q_dist=args.q_dist,
            load_from=args.load_from,
            loss_function=args.loss_function,
            loss_function_param=args.loss_function_param,
            add_noise=args.add_noise,
            noise_type=args.noise_type,
            noise_std=args.noise_std,
            detach=args.detach,
            warm_up=args.warm_up,
        )
    else:
        with open(args.config, "r") as f:
            architecture_args = json.load(f)
        model = model_class(
            input_shape=train_loader.dataset[0][0].shape,
            architecture_args=architecture_args,
            pretrained_arg=args.pretrained_arg,
            device=args.device,
            grad_weight_decay=args.grad_weight_decay,
            grad_l1_penalty=args.grad_l1_penalty,
            lamb=args.lamb,
            sample_from_q=args.sample_from_q,
            q_dist=args.q_dist,
            load_from=args.load_from,
            loss_function=args.loss_function,
            loss_function_param=args.loss_function_param,
            add_noise=args.add_noise,
            noise_type=args.noise_type,
            noise_std=args.noise_std,
            detach=args.detach,
            warm_up=args.warm_up,
        )

    metrics_list = []
    if args.dataset == "imagenet":
        metrics_list.append(metrics.TopKAccuracy(k=5, output_key="pred"))

    training.train(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=args.epochs,
        save_iter=args.save_iter,
        vis_iter=args.vis_iter,
        optimization_args=optimization_args,
        log_dir=args.log_dir,
        args_to_log=args,
        stopping_param=args.stopping_param,
        metrics=metrics_list,
    )

    # if training finishes successfully, compute the test score
    print("Testing the best validation model...")
    model = utils.load(os.path.join(args.log_dir, "checkpoints", "best_val.mdl"),
                       device=args.device)
    pred = utils.apply_on_dataset(
        model,
        test_loader.dataset,
        batch_size=args.batch_size,
        output_keys_regexp="pred",
        description="Testing",
    )["pred"]
    labels = [p[1] for p in test_loader.dataset]
    labels = torch.tensor(labels, dtype=torch.long)
    labels = utils.to_cpu(labels)

    with open(os.path.join(args.log_dir, "test_predictions.pkl"), "wb") as f:
        pickle.dump({"pred": pred, "labels": labels}, f)

    accuracy = torch.mean((pred.argmax(dim=1) == labels).float())
    with open(os.path.join(args.log_dir, "test_accuracy.txt"), "w") as f:
        f.write("{}\n".format(accuracy))
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, required=True)
    parser.add_argument('--device', '-d', default='cuda')
    parser.add_argument('--batch_size', '-b', type=int, default=128)
    parser.add_argument('--epochs', '-e', type=int, default=4000)
    parser.add_argument('--stopping_param', type=int, default=2**30)
    parser.add_argument('--save_iter', '-s', type=int, default=100)
    parser.add_argument('--vis_iter', '-v', type=int, default=10)
    parser.add_argument('--log_dir', '-l', type=str, default=None)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--dataset', '-D', type=str, default='cifar10',
                        choices=['mnist', 'cifar10', 'cifar100', 'clothing1m', 'imagenet'])
    parser.add_argument('--data_augmentation', '-A', action='store_true',
                        dest='data_augmentation')
    parser.set_defaults(data_augmentation=False)
    parser.add_argument('--num_train_examples', type=int, default=None)
    parser.add_argument('--label_noise_level', '-n', type=float, default=0.0)
    parser.add_argument('--label_noise_type', type=str, default='error',
                        choices=['error', 'cifar10_custom'])
    parser.add_argument('--transform_function', type=str, default=None,
                        choices=[None, 'remove_random_chunks'])
    parser.add_argument('--clean_validation', dest='clean_validation', action='store_true')
    parser.set_defaults(clean_validation=False)
    parser.add_argument('--remove_prob', type=float, default=0.5)
    parser.add_argument('--model_class', '-m', type=str, default='StandardClassifier')
    parser.add_argument('--load_from', type=str, default=None)
    parser.add_argument('--grad_weight_decay', '-L', type=float, default=0.0)
    parser.add_argument('--lamb', type=float, default=1.0)
    parser.add_argument('--pretrained_arg', '-r', type=str, default=None)
    parser.add_argument('--sample_from_q', action='store_true', dest='sample_from_q')
    parser.set_defaults(sample_from_q=False)
    parser.add_argument('--q_dist', type=str, default='Gaussian',
                        choices=['Gaussian', 'Laplace', 'dot'])
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--k', '-k', type=int, required=True, default=10,
                        help='width parameter of ResNet18-k')
    args = parser.parse_args()
    print(args)

    # Load data
    train_loader, val_loader, test_loader = datasets.load_data_from_arguments(args)

    # Options
    optimization_args = {
        'optimizer': {
            'name': 'adam',
            'lr': args.lr,
            'weight_decay': args.weight_decay
        }
    }

    with open(args.config, 'r') as f:
        architecture_args = json.load(f)

    # set the width parameter k
    if ('classifier' in architecture_args
            and architecture_args['classifier'].get('net', '') == 'double-descent-cifar10-resnet18'):
        architecture_args['classifier']['k'] = args.k
    if ('q-network' in architecture_args
            and architecture_args['q-network'].get('net', '') == 'double-descent-cifar10-resnet18'):
        architecture_args['q-network']['k'] = args.k

    model_class = getattr(methods, args.model_class)
    model = model_class(input_shape=train_loader.dataset[0][0].shape,
                        architecture_args=architecture_args,
                        pretrained_arg=args.pretrained_arg,
                        device=args.device,
                        grad_weight_decay=args.grad_weight_decay,
                        lamb=args.lamb,
                        sample_from_q=args.sample_from_q,
                        q_dist=args.q_dist,
                        load_from=args.load_from,
                        loss_function='ce')

    training.train(model=model,
                   train_loader=train_loader,
                   val_loader=val_loader,
                   epochs=args.epochs,
                   save_iter=args.save_iter,
                   vis_iter=args.vis_iter,
                   optimization_args=optimization_args,
                   log_dir=args.log_dir,
                   args_to_log=args,
                   stopping_param=args.stopping_param)

    # test the best and the final model
    models_to_test = [{
        'name': 'best',
        'file': 'best_val.mdl'
    }, {
        'name': 'final',
        'file': 'final.mdl'
    }]
    for spec in models_to_test:
        print("Testing the {} model...".format(spec['name']))
        model = utils.load(os.path.join(args.log_dir, 'checkpoints', spec['file']),
                           device=args.device)
        pred = utils.apply_on_dataset(model, test_loader.dataset,
                                      batch_size=args.batch_size,
                                      output_keys_regexp='pred',
                                      description='Testing')['pred']
        labels = [p[1] for p in test_loader.dataset]
        labels = torch.tensor(labels, dtype=torch.long)
        labels = utils.to_cpu(labels)

        with open(os.path.join(args.log_dir,
                               '{}_test_predictions.pkl'.format(spec['name'])), 'wb') as f:
            pickle.dump({'pred': pred, 'labels': labels}, f)

        accuracy = torch.mean((pred.argmax(dim=1) == labels).float())
        with open(os.path.join(args.log_dir,
                               '{}_test_accuracy.txt'.format(spec['name'])), 'w') as f:
            f.write("{}\n".format(accuracy))
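# An illustrative architecture config (assumed shape, inferred only from the keys
# the double-descent scripts above read: 'classifier', 'q-network', 'net', and
# 'k'). A JSON file with this structure could be passed via --config; the output
# path below is hypothetical.
def write_example_double_descent_config(path='configs/double-descent-cifar10.json'):
    import json

    example_config = {
        'classifier': {'net': 'double-descent-cifar10-resnet18', 'k': 10},
        'q-network': {'net': 'double-descent-cifar10-resnet18', 'k': 10},
    }
    with open(path, 'w') as f:
        json.dump(example_config, f, indent=2)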