def run(net):
    """Train ``net`` with Nesterov SGD, decaying the LR when val accuracy plateaus.

    Depends on module-level names defined elsewhere: ``hps`` (indexed with
    ``'lr'``, ``'name'`` and ``'n_epochs'``), ``device``, ``logger`` and the
    helpers ``prepare_data``, ``train``, ``evaluate`` and ``save``.

    Each epoch trains once, evaluates on the validation loader, and appends
    loss and accuracy to ``logger``. The network is checkpointed via ``save``
    whenever validation accuracy beats the best value seen so far.
    """
    train_loader, val_loader = prepare_data()
    net = net.to(device)

    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=hps['lr'],
        momentum=0.9,
        nesterov=True,
        weight_decay=0.0001,
    )
    # mode='max' because the scheduler is stepped with validation accuracy.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=10, verbose=True)
    loss_fn = nn.CrossEntropyLoss()

    best_val_acc = 0
    print("Training", hps['name'], "on", device)
    for epoch_num in range(1, hps['n_epochs'] + 1):
        train_acc, train_loss = train(net, train_loader, loss_fn, optimizer)
        logger.loss_train.append(train_loss)
        logger.acc_train.append(train_acc)

        val_acc, val_loss = evaluate(net, val_loader, loss_fn)
        logger.loss_val.append(val_loss)
        logger.acc_val.append(val_acc)

        scheduler.step(val_acc)

        # Persist the running logs every fifth epoch.
        if epoch_num % 5 == 0:
            logger.save(hps)

        improved = val_acc > best_val_acc
        if improved:
            save(net, hps)
            best_val_acc = val_acc

        columns = [
            'Epoch %2d' % epoch_num,
            'Train Accuracy: %2.2f %%' % train_acc,
            'Val Accuracy: %2.2f %%' % val_acc,
        ]
        if improved:
            columns.append('Network Saved')
        print(*columns, sep='\t\t')
def main():
    """Train a model on a shifted dataset, naming outputs after ``--model_test``.

    Parses CLI arguments, builds the train/val(/test) loaders, constructs the
    model (or, with ``--resume`` and an existing ``log_dir``, reloads
    ``<log_dir>/<model_test>`` with ``torch.load``), picks the loss and hands
    everything to ``train``.

    Writes ``<log_dir>/<model_test>_log.txt`` plus ``train.csv``, ``val.csv``
    and ``test.csv`` in ``log_dir``.
    """
    parser = argparse.ArgumentParser()
    # Settings
    parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)
    parser.add_argument('-s', '--shift_type', choices=shift_types, required=True)
    # Confounders
    parser.add_argument('-t', '--target_name')
    parser.add_argument('-c', '--confounder_names', nargs='+')
    # Resume?
    parser.add_argument('--resume', default=False, action='store_true')
    # Label shifts
    parser.add_argument('--minority_fraction', type=float)
    parser.add_argument('--imbalance_ratio', type=float)
    # Data
    parser.add_argument('--fraction', type=float, default=1.0)
    parser.add_argument('--root_dir', default=None)
    parser.add_argument('--subsample_to_minority', action='store_true', default=False)
    parser.add_argument('--reweight_groups', action='store_true', default=False)
    parser.add_argument('--augment_data', action='store_true', default=False)
    parser.add_argument('--val_fraction', type=float, default=0.1)
    # Objective
    parser.add_argument('--robust', default=False, action='store_true')
    parser.add_argument('--alpha', type=float, default=0.2)
    parser.add_argument('--generalization_adjustment', default="0.0")
    parser.add_argument('--automatic_adjustment', default=False, action='store_true')
    parser.add_argument('--robust_step_size', default=0.01, type=float)
    parser.add_argument('--use_normalized_loss', default=False, action='store_true')
    parser.add_argument('--btl', default=False, action='store_true')
    parser.add_argument('--hinge', default=False, action='store_true')
    # Model
    parser.add_argument('--model', choices=model_attributes.keys(), default='resnet50')
    parser.add_argument('--train_from_scratch', action='store_true', default=False)
    parser.add_argument('--resnet_width', type=int, default=None)
    # Optimization
    parser.add_argument('--n_epochs', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--scheduler', action='store_true', default=False)
    parser.add_argument('--weight_decay', type=float, default=5e-5)
    parser.add_argument('--gamma', type=float, default=0.1)
    parser.add_argument('--minimum_variational_weight', type=float, default=0)
    # Misc
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--show_progress', default=False, action='store_true')
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--log_every', default=50, type=int)
    parser.add_argument('--save_step', type=int, default=10)
    parser.add_argument('--save_best', action='store_true', default=False)
    parser.add_argument('--save_last', action='store_true', default=False)
    parser.add_argument('--model_test', type=str)
    parser.add_argument('--gpu', type=str)
    args = parser.parse_args()
    # --model_test names the log file (and the checkpoint reloaded on --resume);
    # without it the string concatenation below would raise a TypeError.
    if args.model_test is None:
        parser.error('--model_test is required')
    check_args(args)
    model_test = args.model_test
    gpu = args.gpu

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # os.environ values must be str; leave GPU visibility untouched when --gpu is omitted.
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    # BERT-specific configs copied over from run_glue.py
    if args.model == 'bert':
        args.max_grad_norm = 1.0
        args.adam_epsilon = 1e-8
        args.warmup_steps = 0

    # Only resume (append to logs) when asked to AND there is something to resume from.
    if os.path.exists(args.log_dir) and args.resume:
        resume = True
        mode = 'a'
    else:
        resume = False
        mode = 'w'

    ## Initialize logs
    os.makedirs(args.log_dir, exist_ok=True)
    logger = Logger(os.path.join(args.log_dir, model_test + '_log.txt'), mode)
    # Record args
    log_args(args, logger)
    set_seed(args.seed)

    # Data
    # Test data for label_shift_step is not implemented yet
    test_data = None
    test_loader = None
    if args.shift_type == 'confounder':
        train_data, val_data, test_data = prepare_data(args, train=True)
    elif args.shift_type == 'label_shift_step':
        train_data, val_data = prepare_data(args, train=True)

    loader_kwargs = {'batch_size': args.batch_size, 'num_workers': 12, 'pin_memory': True}
    train_loader = train_data.get_loader(train=True, reweight_groups=args.reweight_groups, **loader_kwargs)
    val_loader = val_data.get_loader(train=False, reweight_groups=None, **loader_kwargs)
    if test_data is not None:
        test_loader = test_data.get_loader(train=False, reweight_groups=None, **loader_kwargs)

    data = {
        'train_loader': train_loader,
        'val_loader': val_loader,
        'test_loader': test_loader,
        'train_data': train_data,
        'val_data': val_data,
        'test_data': test_data,
    }
    n_classes = train_data.n_classes

    log_data(data, logger)

    ## Initialize model
    pretrained = not args.train_from_scratch
    if resume:
        model = torch.load(os.path.join(args.log_dir, model_test))
    elif model_attributes[args.model]['feature_type'] in ('precomputed', 'raw_flattened'):
        assert pretrained
        # Load precomputed features
        d = train_data.input_size()[0]
        model = nn.Linear(d, n_classes)
        model.has_aux_logits = False
    elif args.model == 'resnet50':
        model = torchvision.models.resnet50(pretrained=pretrained)
        d = model.fc.in_features
        model.fc = nn.Linear(d, n_classes)
    elif args.model == 'resnet34':
        model = torchvision.models.resnet34(pretrained=pretrained)
        d = model.fc.in_features
        model.fc = nn.Linear(d, n_classes)
    elif args.model == 'wideresnet50':
        model = torchvision.models.wide_resnet50_2(pretrained=pretrained)
        d = model.fc.in_features
        model.fc = nn.Linear(d, n_classes)
    elif args.model == 'resnet50vw':
        assert not pretrained
        assert args.resnet_width is not None
        model = resnet50vw(args.resnet_width, num_classes=n_classes)
    elif args.model == 'resnet18vw':
        assert not pretrained
        assert args.resnet_width is not None
        model = resnet18vw(args.resnet_width, num_classes=n_classes)
    elif args.model == 'resnet10vw':
        assert not pretrained
        assert args.resnet_width is not None
        model = resnet10vw(args.resnet_width, num_classes=n_classes)
    elif args.model == 'bert':
        assert args.dataset == 'MultiNLI'

        from pytorch_transformers import BertConfig, BertForSequenceClassification
        config = BertConfig.from_pretrained('bert-base-uncased', num_labels=3, finetuning_task='mnli')
        model = BertForSequenceClassification.from_pretrained('bert-base-uncased', from_tf=False, config=config)
    else:
        raise ValueError('Model not recognized.')

    logger.flush()

    ## Define the objective
    if args.hinge:
        assert args.dataset in ['CelebA', 'CUB']  # Only supports binary

        def hinge_loss(yhat, y):
            # The torch loss takes in three arguments so we need to split yhat.
            # It also expects classes in {+1.0, -1.0} whereas by default we give them in {0, 1}.
            # Furthermore, if y = 1 it expects the first input to be higher instead of the second,
            # so we need to swap yhat[:, 0] and yhat[:, 1]...
            torch_loss = torch.nn.MarginRankingLoss(margin=1.0, reduction='none')
            y = (y.float() * 2.0) - 1.0
            return torch_loss(yhat[:, 1], yhat[:, 0], y)

        criterion = hinge_loss
    else:
        criterion = torch.nn.CrossEntropyLoss(reduction='none')

    # NOTE(review): this script deliberately never resumes the epoch counter
    # (the original guarded the test.csv-based offset with `if False:`), so
    # epochs restart at 0 even with --resume. Confirm before enabling.
    epoch_offset = 0

    train_csv_logger = CSVBatchLogger(os.path.join(args.log_dir, 'train.csv'), train_data.n_groups, mode=mode)
    val_csv_logger = CSVBatchLogger(os.path.join(args.log_dir, 'val.csv'), train_data.n_groups, mode=mode)
    test_csv_logger = CSVBatchLogger(os.path.join(args.log_dir, 'test.csv'), train_data.n_groups, mode=mode)

    try:
        train(model, criterion, data, logger, train_csv_logger, val_csv_logger, test_csv_logger, args,
              epoch_offset=epoch_offset)
    finally:
        # Close (and flush) the CSV logs even if training crashes.
        train_csv_logger.close()
        val_csv_logger.close()
        test_csv_logger.close()
def main():
    """Run a model over every example of a dataset and save the outputs as ``.npy``.

    By default a pretrained ``resnet18`` truncated to ``--layers_to_extract``
    layers is used, and the result goes to
    ``<root_dir>/features/<model>_<layers>layer.npy``.

    With ``--get_preds_instead_of_features`` the model is instead loaded from
    ``--model_path`` (a ``.pth`` file). The outputs are then saved next to it
    as ``<model_path minus .pth>_preds-on_<dataset>.npy``.
    """
    parser = argparse.ArgumentParser()
    # Settings
    parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)
    parser.add_argument('-s', '--shift_type', choices=shift_types, required=True)
    # Confounders -- this doesn't really matter because we just care about x
    parser.add_argument('-t', '--target_name')
    parser.add_argument('-c', '--confounder_names', nargs='+')
    parser.add_argument('--root_dir', default=None)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('-m', '--model', choices=model_attributes.keys(), default='resnet18')
    parser.add_argument('--model_path', default=None)
    parser.add_argument('--layers_to_extract', type=int, default=1)
    parser.add_argument('--get_preds_instead_of_features', action='store_true', default=False)
    args = parser.parse_args()

    assert args.shift_type == 'confounder'
    # Fail fast (before loading data) instead of hitting AttributeError on None later.
    if args.get_preds_instead_of_features and (
            args.model_path is None or not args.model_path.endswith('.pth')):
        parser.error('--get_preds_instead_of_features requires --model_path ending in .pth')

    args.augment_data = False
    if args.root_dir is None:
        args.root_dir = dataset_attributes[args.dataset]['root_dir']

    set_seed(0)

    full_dataset = prepare_data(args, train=False, return_full_dataset=True)
    loader_kwargs = {'batch_size': args.batch_size, 'num_workers': 4, 'pin_memory': True}
    loader = full_dataset.get_loader(train=False, reweight_groups=None, **loader_kwargs)

    # Initialize model
    if args.get_preds_instead_of_features:
        model = torch.load(args.model_path)
    elif args.model == 'resnet18':
        model = resnet18(pretrained=True, layers_to_extract=args.layers_to_extract)
    else:
        raise ValueError('Model not recognized.')
    model.eval()
    model = model.cuda()

    n = len(full_dataset)
    features = None  # allocated once the output width is known (first batch)
    last_batch = False
    start_pos = 0

    with torch.no_grad():
        for x_batch, _y, _g in tqdm(loader):
            x_batch = x_batch.cuda()
            num_in_batch = x_batch.shape[0]
            assert num_in_batch <= args.batch_size
            # Rows are written sequentially, so only the final batch may be short.
            if num_in_batch < args.batch_size:
                assert not last_batch
                last_batch = True

            end_pos = start_pos + num_in_batch
            features_batch = model(x_batch).data.cpu().numpy()
            if features is None:
                d = features_batch.shape[1]
                print(f'Extracting {d} features per example')
                features = np.empty((n, d))
            features[start_pos:end_pos, :] = features_batch
            start_pos = end_pos

    # np.empty leaves garbage in any row that was never written; refuse to save that.
    if features is None or start_pos != n:
        raise RuntimeError(f'Extracted {start_pos} rows but the dataset has {n} examples')

    if not args.get_preds_instead_of_features:
        features_dir = os.path.join(args.root_dir, 'features')
        os.makedirs(features_dir, exist_ok=True)
        output_path = os.path.join(features_dir, f'{args.model}_{args.layers_to_extract}layer.npy')
    else:
        # Strip only the trailing extension; split('.pth') would also cut at an
        # earlier '.pth' elsewhere in the path.
        output_path = args.model_path[:-len('.pth')] + '_preds-on_' + args.dataset + '.npy'
    np.save(output_path, features)
def main():
    """Distil a ``resnet10vw`` teacher into a (narrower) ``resnet10vw`` student.

    The teacher is loaded from ``<teacher_dir>/10_model.pth`` and its weights
    are copied into a freshly built ``resnet10vw(teacher_width)``. The student
    is trained with ``DistillationLoss(--temp)``.

    Teacher metrics go to ``train.csv``/``val.csv``/``test.csv`` and student
    metrics to ``strain.csv``/``sval.csv``/``stest.csv``, all in ``log_dir``.
    """
    parser = argparse.ArgumentParser()
    # Settings
    parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)
    parser.add_argument('-s', '--shift_type', choices=shift_types, required=True)
    # Confounders
    parser.add_argument('-t', '--target_name')
    parser.add_argument('-c', '--confounder_names', nargs='+')
    # Resume?
    parser.add_argument('--resume', default=False, action='store_true')
    # Label shifts
    parser.add_argument('--minority_fraction', type=float)
    parser.add_argument('--imbalance_ratio', type=float)
    # Data
    parser.add_argument('--fraction', type=float, default=1.0)
    parser.add_argument('--root_dir', default=None)
    parser.add_argument('--subsample_to_minority', action='store_true', default=False)
    parser.add_argument('--reweight_groups', action='store_true', default=False)
    parser.add_argument('--augment_data', action='store_true', default=False)
    parser.add_argument('--val_fraction', type=float, default=0.1)
    # Objective
    parser.add_argument('--robust', default=False, action='store_true')
    parser.add_argument('--alpha', type=float, default=0.2)
    parser.add_argument('--generalization_adjustment', default="0.0")
    parser.add_argument('--automatic_adjustment', default=False, action='store_true')
    parser.add_argument('--robust_step_size', default=0.01, type=float)
    parser.add_argument('--use_normalized_loss', default=False, action='store_true')
    parser.add_argument('--btl', default=False, action='store_true')
    parser.add_argument('--hinge', default=False, action='store_true')
    # Model
    parser.add_argument('--model', choices=model_attributes.keys(), default='resnet50')
    parser.add_argument('--train_from_scratch', action='store_true', default=False)
    parser.add_argument('--resnet_width', type=int, default=None)
    # Optimization
    parser.add_argument('--n_epochs', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--scheduler', action='store_true', default=False)
    parser.add_argument('--weight_decay', type=float, default=5e-5)
    parser.add_argument('--gamma', type=float, default=0.1)
    parser.add_argument('--minimum_variational_weight', type=float, default=0)
    # Misc
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--show_progress', default=False, action='store_true')
    parser.add_argument('--log_dir', default='./logs')
    parser.add_argument('--log_every', default=50, type=int)
    parser.add_argument('--save_step', type=int, default=10)
    parser.add_argument('--save_best', action='store_true', default=False)
    # NOTE(review): store_true with default=True means this flag can never be False.
    parser.add_argument('--save_last', action='store_true', default=True)
    parser.add_argument('--student_width', type=int)
    parser.add_argument('--teacher_dir', type=str)
    parser.add_argument('--teacher_width', type=int)
    parser.add_argument('--gpu', type=str)
    parser.add_argument('--temp', type=str)
    args = parser.parse_args()

    # These are all dereferenced unconditionally below; report them up front
    # instead of crashing after the (slow) data preparation.
    missing = [
        flag for flag, value in (
            ('--teacher_dir', args.teacher_dir),
            ('--teacher_width', args.teacher_width),
            ('--student_width', args.student_width),
            ('--temp', args.temp),
        )
        if value is None
    ]
    if missing:
        parser.error('missing required arguments: ' + ', '.join(missing))

    gpu = args.gpu
    temp = args.temp
    check_args(args)
    teacher_dir = args.teacher_dir
    student_width = args.student_width
    teacher_width = args.teacher_width

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # os.environ values must be str; leave GPU visibility untouched when --gpu is omitted.
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    def DistillationLoss(temperature):
        """Return ``loss(student_logits, teacher_logits, target)`` for distillation.

        The loss is the hard-label cross-entropy of the student plus the
        cross-entropy between the temperature-softened teacher and student
        distributions, scaled by ``temperature**2``.
        """
        cross_entropy = torch.nn.CrossEntropyLoss()

        def loss(student_logits, teacher_logits, target):
            last_dim = len(student_logits.shape) - 1
            p_t = nn.functional.softmax(teacher_logits / temperature, dim=last_dim)
            log_p_s = nn.functional.log_softmax(student_logits / temperature, dim=last_dim)
            soft_term = -(p_t * log_p_s).sum(dim=last_dim).mean() * temperature**2
            return cross_entropy(student_logits, target) + soft_term

        return loss

    # BERT-specific configs copied over from run_glue.py
    if args.model == 'bert':
        args.max_grad_norm = 1.0
        args.adam_epsilon = 1e-8
        args.warmup_steps = 0

    if os.path.exists(args.log_dir) and args.resume:
        resume = True
        mode = 'a'
    else:
        resume = False
        mode = 'w'

    ## Initialize logs
    os.makedirs(args.log_dir, exist_ok=True)
    logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)
    # Record args
    log_args(args, logger)
    set_seed(args.seed)

    print("starting prep")
    # Data
    # Test data for label_shift_step is not implemented yet
    test_data = None
    test_loader = None
    if args.shift_type == 'confounder':
        train_data, val_data, test_data = prepare_data(args, train=True)
    elif args.shift_type == 'label_shift_step':
        train_data, val_data = prepare_data(args, train=True)
    print("done prep")

    loader_kwargs = {'batch_size': args.batch_size, 'num_workers': 16, 'pin_memory': True}
    train_loader = train_data.get_loader(train=True, reweight_groups=args.reweight_groups, **loader_kwargs)
    val_loader = val_data.get_loader(train=False, reweight_groups=None, **loader_kwargs)
    if test_data is not None:
        test_loader = test_data.get_loader(train=False, reweight_groups=None, **loader_kwargs)

    data = {
        'train_loader': train_loader,
        'val_loader': val_loader,
        'test_loader': test_loader,
        'train_data': train_data,
        'val_data': val_data,
        'test_data': test_data,
    }
    n_classes = train_data.n_classes

    log_data(data, logger)
    logger.flush()

    ## Define the objective
    if args.hinge:
        assert args.dataset in ['CelebA', 'CUB']  # Only supports binary

        def hinge_loss(yhat, y):
            # The torch loss takes in three arguments so we need to split yhat.
            # It also expects classes in {+1.0, -1.0} whereas by default we give them in {0, 1}.
            # Furthermore, if y = 1 it expects the first input to be higher instead of the second,
            # so we need to swap yhat[:, 0] and yhat[:, 1]...
            torch_loss = torch.nn.MarginRankingLoss(margin=1.0, reduction='none')
            y = (y.float() * 2.0) - 1.0
            return torch_loss(yhat[:, 1], yhat[:, 0], y)

        criterion = hinge_loss
    else:
        criterion = torch.nn.CrossEntropyLoss(reduction='none')

    if resume:
        # Continue numbering after the last epoch recorded in test.csv.
        df = pd.read_csv(os.path.join(args.log_dir, 'test.csv'))
        epoch_offset = df.loc[len(df) - 1, 'epoch'] + 1
        logger.write(f'starting from epoch {epoch_offset}')
    else:
        epoch_offset = 0

    train_csv_logger = CSVBatchLogger(os.path.join(args.log_dir, 'train.csv'), train_data.n_groups, mode=mode)
    val_csv_logger = CSVBatchLogger(os.path.join(args.log_dir, 'val.csv'), train_data.n_groups, mode=mode)
    test_csv_logger = CSVBatchLogger(os.path.join(args.log_dir, 'test.csv'), train_data.n_groups, mode=mode)
    strain_csv_logger = CSVBatchLogger(os.path.join(args.log_dir, 'strain.csv'), train_data.n_groups, mode=mode)
    sval_csv_logger = CSVBatchLogger(os.path.join(args.log_dir, 'sval.csv'), train_data.n_groups, mode=mode)
    stest_csv_logger = CSVBatchLogger(os.path.join(args.log_dir, 'stest.csv'), train_data.n_groups, mode=mode)
    csv_loggers = [
        train_csv_logger, val_csv_logger, test_csv_logger,
        strain_csv_logger, sval_csv_logger, stest_csv_logger,
    ]

    teacher = resnet10vw(teacher_width, num_classes=n_classes)
    teacher_old = torch.load(os.path.join(teacher_dir, "10_model.pth"))
    for module in teacher_old.modules():
        module._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
    teacher.load_state_dict(teacher_old.state_dict())
    teacher = teacher.to('cuda')

    distill_criterion = DistillationLoss(float(temp))
    student = resnet10vw(int(student_width), num_classes=n_classes).to('cuda')

    try:
        # Student metrics go to the s*-prefixed loggers; previously the teacher's
        # test logger was passed twice and stest.csv was never written.
        train(teacher, student, criterion, distill_criterion, data, logger,
              train_csv_logger, val_csv_logger, test_csv_logger,
              strain_csv_logger, sval_csv_logger, stest_csv_logger,
              args, epoch_offset=epoch_offset)
    finally:
        for csv_logger in csv_loggers:
            csv_logger.close()
import argparse

import tensorflow as tf

from model import model
from data import data

# --datatype values accepted by the 'data' action.
VALID_DATATYPES = ('gigaword', 'reuters', 'cnn')
# Actions dispatched to the method of the same name on the model object.
MODEL_ACTIONS = ('pretrain', 'train', 'test', 'save')


def parse():
    """Parse and return the command-line arguments.

    ``--action`` is required. ``data`` prepares a dataset selected by
    ``--datatype``; ``pretrain``, ``train``, ``test`` and ``save`` call the
    method of the same name on the model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--action', required=True)
    parser.add_argument('--datatype')
    parser.add_argument('--load', action='store_true')
    args = parser.parse_args()
    return args


args = parse()

if args.action == 'data':
    if args.datatype not in VALID_DATATYPES:
        print('Invalid data type.')
    else:
        # Use a distinct name so the imported `data` class is not shadowed.
        data_builder = data(args)
        data_builder.prepare_data(args)
elif args.action not in MODEL_ACTIONS:
    # Previously an unknown action built a session and model and then did nothing.
    print('Invalid action.')
else:
    sess = tf.Session()
    try:
        runner = model(sess, args)
        getattr(runner, args.action)()
    finally:
        sess.close()
def main(args):
    """Train a model with optional cross-validation folds and point up-weighting.

    Steps:
    - Prepares confounder-shift data.
    - Optionally replaces train/val with fold ``args.fold`` via ``folds.get_fold``.
    - Optionally up-samples training rows whose ``args.aug_col`` column in
      ``args.metadata_path`` equals 1.
    - Builds the model with ``get_model`` and calls ``train``.

    Logs go to ``log_dir`` (``log.txt`` plus ``train.csv``/``val.csv``/``test.csv``),
    and optionally to Weights & Biases when ``args.wandb`` is set.

    Raises:
        NotImplementedError: for ``shift_type == "label_shift_step"`` or when resuming.
        ValueError: on inconsistent up-weighting arguments or an unknown shift type.
    """
    if args.wandb:
        wandb.init(project=f"{args.project_name}_{args.dataset}")
        wandb.config.update(args)

    # BERT-specific configs copied over from run_glue.py
    if args.model.startswith("bert") and args.use_bert_params:
        args.max_grad_norm = 1.0
        args.adam_epsilon = 1e-8
        args.warmup_steps = 0

    if os.path.exists(args.log_dir) and args.resume:
        resume = True
        mode = "a"
    else:
        resume = False
        mode = "w"

    ## Initialize logs
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    logger = Logger(os.path.join(args.log_dir, "log.txt"), mode)
    # Record args
    log_args(args, logger)

    set_seed(args.seed)

    # Data
    test_data = None
    test_loader = None
    if args.shift_type == "confounder":
        train_data, val_data, test_data = prepare_data(args, train=True)
    elif args.shift_type == "label_shift_step":
        # Intended path once implemented (no test split for this shift type):
        #   train_data, val_data = prepare_data(args, train=True)
        raise NotImplementedError("shift_type 'label_shift_step' is not supported")
    else:
        raise ValueError(f"Unsupported shift_type: {args.shift_type}")

    ######################## Prepare data for our method ########################
    # Should probably not be upweighting if folds are specified.
    # (Explicit raise rather than assert so the check survives `python -O`.)
    if args.fold and args.up_weight:
        raise ValueError("--fold and --up_weight cannot be used together")

    # Fold passed. Use it as train and valid.
    if args.fold:
        train_data, val_data = folds.get_fold(
            train_data,
            args.fold,
            cross_validation_ratio=(1 / args.num_folds_per_sweep),
            num_valid_per_point=args.num_sweeps,
            seed=args.seed,
        )

    if args.up_weight != 0:
        if args.aug_col is None:
            raise ValueError("--up_weight requires --aug_col")
        # Get points that should be upsampled
        metadata_df = pd.read_csv(args.metadata_path)
        if args.dataset == "jigsaw":
            train_col = metadata_df[metadata_df["split"] == "train"]
        else:
            train_col = metadata_df[metadata_df["split"] == 0]
        # Positions (within the train rows of the metadata) flagged for up-sampling.
        # NOTE(review): assumes train_data is ordered like the metadata's train rows.
        aug_indices = np.where(train_col[args.aug_col] == 1)[0]
        print("len", len(train_col), len(aug_indices))

        if args.up_weight == -1:
            # Automatic factor: derived from the ratio of unflagged to flagged points.
            if len(aug_indices) == 0:
                raise ValueError(
                    f"--up_weight -1 needs at least one row with {args.aug_col} == 1")
            up_weight_factor = int((len(train_col) - len(aug_indices)) / len(aug_indices)) - 1
        else:
            up_weight_factor = args.up_weight

        print(f"Up-weight factor: {up_weight_factor}")
        upsampled_points = Subset(train_data, list(aug_indices) * up_weight_factor)
        # Convert to DRODataset
        train_data = dro_dataset.DRODataset(
            ConcatDataset([train_data, upsampled_points]),
            process_item_fn=None,
            n_groups=train_data.n_groups,
            n_classes=train_data.n_classes,
            group_str_fn=train_data.group_str,
        )
    elif args.aug_col is not None:
        print("\n" * 2 + "WARNING: aug_col is not being used." + "\n" * 2)
    #############################################################################

    loader_kwargs = {"batch_size": args.batch_size, "num_workers": 4, "pin_memory": True}
    train_loader = dro_dataset.get_loader(train_data, train=True, reweight_groups=args.reweight_groups,
                                          **loader_kwargs)
    val_loader = dro_dataset.get_loader(val_data, train=False, reweight_groups=None, **loader_kwargs)
    if test_data is not None:
        test_loader = dro_dataset.get_loader(test_data, train=False, reweight_groups=None, **loader_kwargs)

    data = {
        "train_loader": train_loader,
        "val_loader": val_loader,
        "test_loader": test_loader,
        "train_data": train_data,
        "val_data": val_data,
        "test_data": test_data,
    }
    log_data(data, logger)

    ## Initialize model
    model = get_model(
        model=args.model,
        pretrained=not args.train_from_scratch,
        resume=resume,
        n_classes=train_data.n_classes,
        dataset=args.dataset,
        log_dir=args.log_dir,
    )
    if args.wandb:
        wandb.watch(model)

    logger.flush()

    ## Define the objective
    if args.hinge:
        assert args.dataset in ["CelebA", "CUB"]  # Only supports binary
        criterion = hinge_loss
    else:
        criterion = torch.nn.CrossEntropyLoss(reduction="none")

    if resume:
        # Intended (unverified) resume logic:
        #   df = pd.read_csv(os.path.join(args.log_dir, "test.csv"))
        #   epoch_offset = df.loc[len(df) - 1, "epoch"] + 1
        raise NotImplementedError("Resuming is not verified for this script")
    epoch_offset = 0

    train_csv_logger = CSVBatchLogger(os.path.join(args.log_dir, "train.csv"), train_data.n_groups, mode=mode)
    val_csv_logger = CSVBatchLogger(os.path.join(args.log_dir, "val.csv"), val_data.n_groups, mode=mode)
    test_csv_logger = CSVBatchLogger(os.path.join(args.log_dir, "test.csv"), test_data.n_groups, mode=mode)

    try:
        train(
            model,
            criterion,
            data,
            logger,
            train_csv_logger,
            val_csv_logger,
            test_csv_logger,
            args,
            epoch_offset=epoch_offset,
            csv_name=args.fold,
            wandb=wandb if args.wandb else None,
        )
    finally:
        # Close (and flush) the CSV logs even if training crashes.
        train_csv_logger.close()
        val_csv_logger.close()
        test_csv_logger.close()
def main():
    """Fine-tune a pretrained classifier and optionally dump its predictions.

    Supports ``resnet50`` and ``densenet121`` (torchvision, with the final
    linear head replaced and, with ``--mc_dropout``, wrapped by
    ``add_dropout``) and ``bert-base-uncased``. Logs are written to
    ``log_dir`` in overwrite mode; predictions are saved via ``save_preds``
    when ``args.save_preds`` is set.
    """
    print("Loading and checking args...")
    args = parse_args()
    check_args(args)

    # BERT-specific configs copied over from run_glue.py
    if args.model.startswith('bert'):
        args.max_grad_norm = 1.0
        args.adam_epsilon = 1e-8
        args.warmup_steps = 0

    # Logs are always overwritten: this script assumes no previous run exists.
    mode = 'w'
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)
    log_args(args, logger)
    set_seed(args.seed)

    print("Preparing data")
    train_data, val_data, test_data = prepare_data(args, train=True)

    print("Setting up loader")
    loader_kwargs = {'batch_size': args.batch_size, 'num_workers': 4, 'pin_memory': True}
    data = {}
    data['train_loader'] = train_data.get_loader(train=True, reweight_groups=args.reweight_groups,
                                                 **loader_kwargs)
    data['val_loader'] = val_data.get_loader(train=False, reweight_groups=None, **loader_kwargs)
    data['test_loader'] = test_data.get_loader(train=False, reweight_groups=None, **loader_kwargs)
    data['train_data'] = train_data
    data['val_data'] = val_data
    data['test_data'] = test_data
    n_classes = train_data.n_classes
    log_data(data, logger)

    # Name of the final linear layer for each supported torchvision backbone.
    head_attr = {'resnet50': 'fc', 'densenet121': 'classifier'}.get(args.model)
    if head_attr is not None:
        backbone = getattr(torchvision.models, args.model)(pretrained=True)
        in_dim = getattr(backbone, head_attr).in_features
        setattr(backbone, head_attr, nn.Linear(in_dim, n_classes))
        model = add_dropout(backbone, head_attr) if args.mc_dropout else backbone
    elif args.model == 'bert-base-uncased':
        print("Loading bert")
        model = BertForSequenceClassification.from_pretrained(args.model, num_labels=n_classes)
    else:
        raise ValueError('Model not recognized.')

    logger.flush()
    criterion = torch.nn.CrossEntropyLoss(reduction='none')

    print("Getting loggers")
    csv_loggers = [
        CSVBatchLogger(os.path.join(args.log_dir, file_name), train_data.n_groups, mode=mode)
        for file_name in ('train.csv', 'val.csv', 'test.csv')
    ]
    train_csv_logger, val_csv_logger, test_csv_logger = csv_loggers

    print("Starting to train...")
    train(model, criterion, data, logger, train_csv_logger, val_csv_logger, test_csv_logger, args,
          epoch_offset=0, train=True)

    for csv_logger in csv_loggers:
        csv_logger.close()

    if args.save_preds:
        save_preds(model, data, args)