def main():
    """Train a CRF on quick-draw features with per-sample (SAG-style) updates.

    Periodically evaluates MAP@3 on the validation set and optionally logs
    metrics to Weights & Biases. All configuration comes from command-line
    flags.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_loc', default='./cv_simplified')
    parser.add_argument('--checkpoint_loading_path', default='./saved_model')
    parser.add_argument('--checkpoint_saving_path', default='./saved_model')
    parser.add_argument('--on_cluster', action='store_true', default=False)
    parser.add_argument('--resume', action='store_true', default=False)
    parser.add_argument('--line_search', action='store_true', default=False)
    parser.add_argument('--wandb_logging', action='store_true', default=False)
    parser.add_argument('--wandb_exp_name')
    parser.add_argument('--step_size', type=float, default=1e-06)
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--regularization_param', type=float, default=1e-03)
    parser.add_argument('--max_pass', type=int, default=10)
    parser.add_argument('--opt_method', default='adam')
    args = parser.parse_args()

    if args.on_cluster:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    if args.wandb_logging:
        wandb.init(project='quick_draw_crf', name=args.wandb_exp_name)
        # NOTE(review): 'sag' is hard-coded here (the per-sample crf.update
        # loop below is SAG-style) and deliberately ignores --opt_method.
        wandb.config.update({
            "step_size": args.step_size,
            "l2_regularize": args.regularization_param,
            "opt_method": 'sag'
        })

    torch.manual_seed(1)
    num_features = 851968
    num_cats = 113
    num_val_data_per_cat = 500

    data_fh = h5py.File(args.dataset_loc, 'r')
    num_tr_data = len(data_fh['tr_data'])
    tr_dataset = Image_dataset(data=data_fh['tr_data'],
                               label=data_fh['tr_label'],
                               data_size=num_tr_data,
                               num_cats=num_cats,
                               device=device)
    val_dataset = Validation_dataset(data=data_fh['val_data'],
                                     label=data_fh['val_label'],
                                     num_cats=num_cats,
                                     num_data_per_cat=num_val_data_per_cat,
                                     num_features=num_features,
                                     device=device)
    train_loader = data.DataLoader(tr_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True)
    validate_loader = data.DataLoader(val_dataset, batch_size=args.batch_size)
    # Assumed no zero entries here!
    validate_data_true_label = parse_validation_data_labels(
        data_fh['val_label'][:])

    crf = CRF(num_cats, num_features, num_tr_data, args.step_size,
              args.regularization_param, args.line_search, device)

    # BUG FIX: the update counter was named `iter`, shadowing the builtin.
    num_updates = 0
    for num_pass in range(args.max_pass):
        for image_batch, label_vecs, labels, indicies in train_loader:
            num_data_in_batch = image_batch.shape[0]
            for i in range(num_data_in_batch):
                crf.update(image_batch[i], label_vecs[i], labels[i],
                           indicies[i])
                num_updates += 1
                # Validate every 1000 per-sample updates.
                if num_updates % 1000 == 0:
                    val_err = validate(crf, validate_data_true_label,
                                       validate_loader,
                                       val_dataset.data_size, device)
                    if args.wandb_logging:
                        grad = torch.abs(crf.full_grad /
                                         crf.num_sample_visited)
                        # Line search tracks an inverse learning rate
                        # — presumably; verify against the CRF class.
                        step_size = (1 / crf.line_search_lr
                                     if args.line_search else crf.lr)
                        wandb.log({
                            'val_err': val_err,
                            # One torch.sum suffices (was a redundant
                            # double sum).
                            'grad_l1': torch.sum(grad),
                            'grad_l_inf': torch.max(grad),
                            'step_size': step_size
                        })
    # BUG FIX: the HDF5 handle was never closed.
    data_fh.close()
def validate(resnet, validate_data_true_label, validate_loader,
             val_dataset_size, device):
    """Return the MAP@3 score of `resnet` over the validation loader.

    Args:
        resnet: classifier module; called on each image batch.
        validate_data_true_label: ground-truth labels in the format
            expected by `mapk`.
        validate_loader: DataLoader yielding (image_batch, label) pairs.
        val_dataset_size: unused here; kept for interface compatibility.
        device: device the image batches are moved to.
    """
    all_predictions = []
    resnet.eval()
    # BUG FIX: inference previously ran without no_grad(), so autograd
    # recorded the whole forward pass (wasted memory and time). This also
    # removes the deprecated `.data` access on the outputs.
    with torch.no_grad():
        for image_batch, _ in validate_loader:
            image_batch = image_batch.to(device)
            outputs = resnet(image_batch)
            _, predicted = torch.topk(outputs, 3)
            # extend flattens the per-batch rows into one flat list.
            all_predictions.extend(predicted.cpu().numpy())
    all_predictions = np.squeeze(np.array(all_predictions))
    return mapk(actual=validate_data_true_label,
                predicted=parse_validation_data_labels(all_predictions),
                k=3)
def main():
    """Train a CRF with Adam or SGD on pickled quick-draw features.

    Gradients are accumulated over 10 mini-batches per optimizer step.
    Validation MAP@3 is computed once per epoch and either logged to
    Weights & Biases or printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_loc', default='./cv_simplified')
    parser.add_argument('--checkpoint_loading_path', default='./saved_model')
    parser.add_argument('--checkpoint_saving_path', default='./saved_model')
    parser.add_argument('--on_cluster', action='store_true', default=False)
    parser.add_argument('--resume', action='store_true', default=False)
    parser.add_argument('--wandb_logging', action='store_true', default=False)
    parser.add_argument('--wandb_exp_name')
    parser.add_argument('--step_size', type=float, default=1e-06)
    parser.add_argument('--weight_decay', type=float, default=1e-06)
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--regularization_param', type=float, default=1e-03)
    parser.add_argument('--max_epoch', type=int, default=100)
    parser.add_argument('--opt_method', default='adam')
    args = parser.parse_args()

    if args.on_cluster:
        # On SLURM the dataset is staged to node-local temporary storage.
        dataset_loc = os.environ.get('SLURM_TMPDIR', '.')
        dataset_loc = os.path.join(dataset_loc, args.dataset_loc)
        device = torch.device('cuda')
    else:
        dataset_loc = args.dataset_loc
        device = torch.device('cpu')

    if args.wandb_logging:
        wandb.init(project='quick_draw_crf', name=args.wandb_exp_name)
        wandb.config.update({
            "step_size": args.step_size,
            "weight_decay": args.weight_decay,
            "opt_method": args.opt_method
        })

    torch.manual_seed(1)
    num_cats = 5
    num_features = 851968

    tr_dataset = Image_dataset(data_path=os.path.join(dataset_loc, 'tr.pkl'),
                               num_features=num_features,
                               device=device)
    val_dataset = Image_dataset(data_path=os.path.join(dataset_loc,
                                                       'val.pkl'),
                                num_features=num_features,
                                device=device)
    train_loader = data.DataLoader(tr_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True)
    validate_loader = data.DataLoader(val_dataset, batch_size=args.batch_size)
    # BUG FIX: previously read from args.dataset_loc, which is the wrong
    # location when --on_cluster staged the data under SLURM_TMPDIR.
    validate_data_true_label = parse_validation_data_labels(
        pd.read_pickle(os.path.join(dataset_loc,
                                    'val.pkl'))['cat_ind'].values)

    weight_param = torch.randn((num_cats, num_features),
                               dtype=torch.float64,
                               device=device,
                               requires_grad=True)
    crf = CRF(weight_param, num_features, num_cats).to(device)

    if args.opt_method == 'adam':
        optimizer = optim.Adam(crf.parameters(),
                               lr=args.step_size,
                               weight_decay=args.weight_decay)
    elif args.opt_method == 'sgd':
        optimizer = optim.SGD(crf.parameters(),
                              lr=args.step_size,
                              weight_decay=args.weight_decay)
    else:
        # BUG FIX: an unrecognized method previously fell through and
        # crashed later with a NameError on `optimizer`.
        raise ValueError('unsupported --opt_method: %s' % args.opt_method)

    crf.name = 'crf_' + args.opt_method
    print('initial validation error = %.3f' %
          (validate(crf, validate_data_true_label, validate_loader,
                    val_dataset.data_size, device)))

    for epoch in range(args.max_epoch):
        optimizer.zero_grad()
        # BUG FIX: adam_grad_l1 was referenced before assignment whenever
        # fewer than 10 batches had been processed.
        adam_grad_l1 = 0.0
        i = 1
        for image_batch, data_ids, data_labels in train_loader:
            loss = crf(image_batch, data_labels, args.regularization_param,
                       device)
            loss.backward()
            # Accumulate gradients over 10 batches before each step.
            if i % 10 == 0:
                optimizer.step()
                # One torch.sum suffices (was a redundant double sum).
                adam_grad_l1 = torch.sum(torch.abs(crf.linear.weight.grad))
                optimizer.zero_grad()
            i += 1
        val_err = validate(crf, validate_data_true_label, validate_loader,
                           val_dataset.data_size, device)
        if args.wandb_logging:
            wandb.log({
                'val_err': val_err,
                # log a Python scalar, not a live tensor
                'tr_loss': loss.item(),
                'adam_grad_l1': adam_grad_l1
            })
        else:
            print('epoch=%d. ' % (epoch) +
                  '. validation error = %.3f' % (val_err))
        save_checkpoint(crf, args.checkpoint_saving_path, epoch)
def validate(crf, validate_data_true_label, validate_loader,
             val_dataset_size, device):
    """Return the MAP@3 score of the CRF over the validation loader.

    Args:
        crf: model exposing `make_predictions(image_batch)`.
        validate_data_true_label: ground-truth labels in the format
            expected by `mapk`.
        validate_loader: DataLoader yielding (images, ids, labels).
        val_dataset_size: unused here; kept for interface compatibility.
        device: unused here; kept for interface compatibility.
    """
    all_predictions = []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for image_batch, data_ids, _ in validate_loader:
            # BUG FIX: was append(), which built a ragged list of per-batch
            # 2-D arrays; np.array then fails (or yields an object array)
            # whenever the last batch is smaller. extend() flattens rows,
            # matching the resnet validate() in this file.
            all_predictions.extend(
                crf.make_predictions(image_batch).cpu().numpy())
    all_predictions = np.squeeze(np.array(all_predictions))
    return mapk(actual=validate_data_true_label,
                predicted=parse_validation_data_labels(all_predictions),
                k=3)
def main():
    """Train a torchvision ResNet on sharded quick-draw HDF5 data.

    Each pass trains on one randomly chosen training shard. Gradients are
    accumulated over --num_iter_accum_grad batches per optimizer step, and
    validation MAP@3 is computed after every step. Checkpoints are saved
    once per pass.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_loc', default='./cv_simplified')
    parser.add_argument('--model_loading_path', default='./saved_model')
    parser.add_argument('--model_saving_path', default='./saved_model')
    parser.add_argument('--load_model', action='store_true', default=False)
    parser.add_argument('--on_cluster', action='store_true', default=False)
    parser.add_argument('--resume', action='store_true', default=False)
    parser.add_argument('--line_search', action='store_true', default=False)
    parser.add_argument('--wandb_logging', action='store_true', default=False)
    parser.add_argument('--wandb_exp_name')
    parser.add_argument('--step_size', type=float, default=1e-06)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--num_iter_accum_grad', type=int, default=1)
    parser.add_argument('--max_pass', type=int, default=10)
    parser.add_argument('--opt_method', default='adam')
    parser.add_argument('--model_name', default='resnet18')
    args = parser.parse_args()

    device = torch.device('cuda') if args.on_cluster else torch.device('cpu')
    if args.on_cluster:
        # Keep torch's download/cache dir on node-local storage.
        os.environ['TORCH_HOME'] = os.environ.get('SLURM_TMPDIR', '.')

    if args.wandb_logging:
        wandb.init(project='quick_draw_crf', name=args.wandb_exp_name)
        wandb.config.update({
            "step_size": args.step_size,
            "opt_method": args.opt_method,
            "num_data_used_calc_grad":
                args.batch_size * args.num_iter_accum_grad,
            'model_name': args.model_name
        })

    num_cats = 340
    num_tr_files = 20
    val_data_fh = h5py.File(
        os.path.join(args.dataset_loc, 'quick_draw_resnet_val_data.hdf5'),
        'r')
    val_dataset = Image_dataset(data=val_data_fh['val_data'],
                                label=val_data_fh['val_label'],
                                data_size=len(val_data_fh['val_data']))
    validate_loader = data.DataLoader(val_dataset,
                                      batch_size=args.batch_size)
    # Assumed no zero entries here!
    validate_data_true_label = parse_validation_data_labels(
        val_data_fh['val_label'][:])

    # Dispatch table replaces the if/elif chain; an unknown name previously
    # left `resnet` undefined and crashed later with a NameError.
    model_builders = {
        'resnet18': torchvision.models.resnet18,
        'resnet34': torchvision.models.resnet34,
        'resnet50': torchvision.models.resnet50,
    }
    if args.model_name not in model_builders:
        raise ValueError('unsupported --model_name: %s' % args.model_name)
    resnet = model_builders[args.model_name](
        pretrained=False, num_classes=num_cats).to(device)

    optimizer = optim.Adam(resnet.parameters(),
                           lr=args.step_size,
                           weight_decay=args.weight_decay)
    if args.load_model:
        checkpoint = torch.load(args.model_loading_path)
        resnet.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        resnet.train()
    else:
        epoch = 0
    criterion = nn.CrossEntropyLoss()

    while epoch < args.max_pass:
        optimizer.zero_grad()
        i = 1
        # Each pass trains on one randomly chosen shard.
        tr_data_fh = h5py.File(
            os.path.join(
                args.dataset_loc, 'quick_draw_resnet_data_' +
                str(np.random.randint(0, num_tr_files)) + '.hdf5'), 'r')
        tr_dataset = Image_dataset(data=tr_data_fh['tr_data'],
                                   label=tr_data_fh['tr_label'],
                                   data_size=len(tr_data_fh['tr_data']))
        train_loader = data.DataLoader(tr_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True)
        for image_batch, data_labels in train_loader:
            image_batch = image_batch.to(device)
            data_labels = data_labels.to(device)
            # validate() flips the model to eval mode, so restore train
            # mode every batch.
            resnet.train()
            outputs = resnet(image_batch)
            loss = criterion(outputs, data_labels)
            loss.backward()
            if i % args.num_iter_accum_grad == 0:
                optimizer.step()
                optimizer.zero_grad()
                val_err = validate(resnet, validate_data_true_label,
                                   validate_loader, val_dataset.data_size,
                                   device)
                if args.wandb_logging:
                    wandb.log({'val_err': val_err, 'tr_loss': loss.item()})
                else:
                    print('epoch=%d. ' % (epoch) +
                          '. validation error = %.3f' % (val_err))
            i += 1
        # BUG FIX: the per-pass shard handle previously leaked every pass.
        tr_data_fh.close()
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': resnet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, os.path.join(args.model_saving_path, args.model_name + '.pt'))
        epoch += 1