Example #1
import argparse

import h5py
import torch
import wandb
from torch.utils import data

# Project-local helpers (CRF, Image_dataset, Validation_dataset, validate,
# parse_validation_data_labels) are assumed importable from the repo.
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_loc', default='./cv_simplified')
    parser.add_argument('--checkpoint_loading_path', default='./saved_model')
    parser.add_argument('--checkpoint_saving_path', default='./saved_model')
    parser.add_argument('--on_cluster', action='store_true', default=False)
    parser.add_argument('--resume', action='store_true', default=False)
    parser.add_argument('--line_search', action='store_true', default=False)
    parser.add_argument('--wandb_logging', action='store_true', default=False)
    parser.add_argument('--wandb_exp_name')
    parser.add_argument('--step_size', type=float, default=1e-06)
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--regularization_param', type=float, default=1e-03)
    parser.add_argument('--max_pass', type=int, default=10)
    parser.add_argument('--opt_method', default='sag')
    args = parser.parse_args()

    if args.on_cluster:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    if args.wandb_logging:
        wandb.init(project='quick_draw_crf', name=args.wandb_exp_name)
        wandb.config.update({"step_size": args.step_size, "l2_regularize": args.regularization_param,
                             "opt_method":'sag'})

    torch.manual_seed(1)
    num_features = 851968
    num_cats = 113
    num_val_data_per_cat = 500

    data_fh = h5py.File(args.dataset_loc, 'r')
    num_tr_data = len(data_fh['tr_data'])
    tr_dataset = Image_dataset(data=data_fh['tr_data'], label=data_fh['tr_label'], data_size=num_tr_data, num_cats=num_cats, device=device)
    val_dataset = Validation_dataset(data=data_fh['val_data'], label=data_fh['val_label'], num_cats=num_cats, num_data_per_cat=num_val_data_per_cat, num_features=num_features, device=device)
    train_loader = data.DataLoader(tr_dataset, batch_size=args.batch_size, shuffle=True)
    validate_loader = data.DataLoader(val_dataset, batch_size=args.batch_size)
    validate_data_true_label = parse_validation_data_labels(data_fh['val_label'][:]) # Assumed no zero entries here!

    crf = CRF(num_cats, num_features, num_tr_data, args.step_size, args.regularization_param, args.line_search, device)

    num_iter = 0  # renamed to avoid shadowing the built-in iter()
    for num_pass in range(args.max_pass):
        for image_batch, label_vecs, labels, indices in train_loader:
            num_data_in_batch = image_batch.shape[0]
            for i in range(num_data_in_batch):
                crf.update(image_batch[i], label_vecs[i], labels[i], indices[i])
                num_iter += 1
                if num_iter % 1000 == 0:
                    val_err = validate(crf, validate_data_true_label, validate_loader, val_dataset.data_size, device)
                    if args.wandb_logging:
                        grad = torch.abs(crf.full_grad / crf.num_sample_visited)
                        step_size = 1 / crf.line_search_lr if args.line_search else crf.lr
                        wandb.log({'val_err': val_err, 'grad_l1': torch.sum(grad).item(), 'grad_l_inf': torch.max(grad).item(), 'step_size': step_size})
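The parse_validation_data_labels helper used above is not defined in these snippets. Below is a minimal sketch under the assumption, flagged by the "no zero entries" comment, that labels arrive either as one true label per example or as a zero-padded 2-D array, and that the helper returns the list-of-label-lists format mapk expects.

import numpy as np

def parse_validation_data_labels(label_array):
    # Hypothetical sketch of the repo's helper, not the confirmed version.
    label_array = np.asarray(label_array)
    if label_array.ndim == 1:
        # One true label per example -> wrap each label in its own list.
        return [[int(label)] for label in label_array]
    # Zero-padded rows -> keep only the non-zero entries of each row.
    return [row[row != 0].tolist() for row in label_array]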
Example #2
import numpy as np
import torch
def validate(resnet, validate_data_true_label, validate_loader,
             val_dataset_size, device):
    all_predictions = []
    resnet.eval()
    with torch.no_grad():  # gradients are not needed during validation
        for image_batch, _ in validate_loader:
            image_batch = image_batch.to(device)
            outputs = resnet(image_batch)
            _, predicted = torch.topk(outputs, 3)  # top-3 class indices per image
            all_predictions.extend(predicted.cpu().numpy())
    all_predictions = np.squeeze(np.array(all_predictions))
    return mapk(actual=validate_data_true_label,
                predicted=parse_validation_data_labels(all_predictions),
                k=3)
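mapk is likewise external to these snippets. The sketch below follows the standard MAP@k formulation (as in the ml_metrics package commonly used for Kaggle's Quick, Draw! scoring): per-example average precision at k, averaged over the validation set.

import numpy as np

def apk(actual, predicted, k=3):
    # Average precision at k for one example: each correct, previously
    # unseen prediction contributes the precision at its rank.
    if len(predicted) > k:
        predicted = predicted[:k]
    score, num_hits = 0.0, 0.0
    for i, p in enumerate(predicted):
        if p in actual and p not in predicted[:i]:
            num_hits += 1.0
            score += num_hits / (i + 1.0)
    if not actual:
        return 0.0
    return score / min(len(actual), k)

def mapk(actual, predicted, k=3):
    # Mean of apk over all (true labels, predictions) pairs.
    return np.mean([apk(a, p, k) for a, p in zip(actual, predicted)])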
Example #3
import argparse
import os

import pandas as pd
import torch
import torch.optim as optim
import wandb
from torch.utils import data

# CRF, Image_dataset, validate, save_checkpoint and
# parse_validation_data_labels are assumed importable from the repo.
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_loc', default='./cv_simplified')
    parser.add_argument('--checkpoint_loading_path', default='./saved_model')
    parser.add_argument('--checkpoint_saving_path', default='./saved_model')
    parser.add_argument('--on_cluster', action='store_true', default=False)
    parser.add_argument('--resume', action='store_true', default=False)
    parser.add_argument('--wandb_logging', action='store_true', default=False)
    parser.add_argument('--wandb_exp_name')
    parser.add_argument('--step_size', type=float, default=1e-06)
    parser.add_argument('--weight_decay', type=float, default=1e-06)
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--regularization_param', type=float, default=1e-03)
    parser.add_argument('--max_epoch', type=int, default=100)
    parser.add_argument('--opt_method', default='adam')
    args = parser.parse_args()

    if args.on_cluster:
        dataset_loc = os.environ.get('SLURM_TMPDIR', '.')
        dataset_loc = os.path.join(dataset_loc, args.dataset_loc)
        device = torch.device('cuda')
    else:
        dataset_loc = args.dataset_loc
        device = torch.device('cpu')

    if args.wandb_logging:
        wandb.init(project='quick_draw_crf', name=args.wandb_exp_name)
        wandb.config.update({
            "step_size": args.step_size,
            "weight_decay": args.weight_decay,
            "opt_method": args.opt_method
        })

    torch.manual_seed(1)
    num_cats = 5
    num_features = 851968
    tr_dataset = Image_dataset(data_path=os.path.join(dataset_loc, 'tr.pkl'),
                               num_features=num_features,
                               device=device)
    val_dataset = Image_dataset(data_path=os.path.join(dataset_loc, 'val.pkl'),
                                num_features=num_features,
                                device=device)
    train_loader = data.DataLoader(tr_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True)
    validate_loader = data.DataLoader(val_dataset, batch_size=args.batch_size)
    validate_data_true_label = parse_validation_data_labels(
        pd.read_pickle(os.path.join(dataset_loc,
                                    'val.pkl'))['cat_ind'].values)

    weight_param = torch.randn((num_cats, num_features),
                               dtype=torch.float64,
                               device=device,
                               requires_grad=True)
    crf = CRF(weight_param, num_features, num_cats).to(device)
    if args.opt_method == 'adam':
        optimizer = optim.Adam(crf.parameters(),
                               lr=args.step_size,
                               weight_decay=args.weight_decay)
    elif args.opt_method == 'sgd':
        optimizer = optim.SGD(crf.parameters(),
                              lr=args.step_size,
                              weight_decay=args.weight_decay)
    else:
        raise ValueError('unsupported --opt_method: %s' % args.opt_method)
    crf.name = 'crf_' + args.opt_method

    print('initial validation error = %.3f' %
          (validate(crf, validate_data_true_label, validate_loader,
                    val_dataset.data_size, device)))

    for epoch in range(args.max_epoch):
        optimizer.zero_grad()
        i = 1
        for image_batch, data_ids, data_labels in train_loader:
            loss = crf(image_batch, data_labels, args.regularization_param,
                       device)
            loss.backward()
            if i % 10 == 0:  # take a step every 10 accumulated mini-batches
                optimizer.step()
                adam_grad_l1 = torch.sum(torch.abs(crf.linear.weight.grad))
                optimizer.zero_grad()
            i += 1
        val_err = validate(crf, validate_data_true_label, validate_loader,
                           val_dataset.data_size, device)
        if args.wandb_logging:
            wandb.log({
                'val_err': val_err,
                'tr_loss': loss.item(),
                'adam_grad_l1': adam_grad_l1.item()
            })
        else:
            print('epoch=%d. validation error = %.3f' % (epoch, val_err))

    save_checkpoint(crf, args.checkpoint_saving_path, epoch)
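save_checkpoint is also not shown. Below is a minimal sketch that mirrors the inline torch.save pattern of Example #5; since Example #3 does not pass the optimizer in, its state is omitted, and the filename is derived from the name attribute set above (both are assumptions).

import os
import torch

def save_checkpoint(model, saving_path, epoch):
    # Hypothetical sketch: persist the epoch counter and model weights.
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
        }, os.path.join(saving_path, model.name + '.pt'))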
Example #4
import numpy as np
import torch
def validate(crf, validate_data_true_label, validate_loader, val_dataset_size, device):
    all_predictions = []
    with torch.no_grad():  # gradients are not needed during validation
        for image_batch, data_ids, _ in validate_loader:
            all_predictions.append(crf.make_predictions(image_batch).cpu().numpy())
    all_predictions = np.concatenate(all_predictions)  # stack per-batch arrays into (N, 3)
    return mapk(actual=validate_data_true_label, predicted=parse_validation_data_labels(all_predictions), k=3)
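crf.make_predictions is assumed to return the top-3 category indices per image, mirroring the torch.topk call in Example #2. A hypothetical method sketch, assuming the CRF scores categories through the linear layer referenced in Example #3:

import torch

def make_predictions(self, image_batch):
    # Hypothetical sketch: score every category, keep the indices of the
    # 3 highest-scoring ones, matching the k=3 used by mapk.
    with torch.no_grad():
        scores = self.linear(image_batch)       # (batch_size, num_cats)
        _, top3 = torch.topk(scores, 3, dim=1)  # (batch_size, 3)
    return top3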
Example #5
import argparse
import os

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import wandb
from torch.utils import data

# Image_dataset, validate and parse_validation_data_labels are assumed
# importable from the repo.
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_loc', default='./cv_simplified')
    parser.add_argument('--model_loading_path', default='./saved_model')
    parser.add_argument('--model_saving_path', default='./saved_model')
    parser.add_argument('--load_model', action='store_true', default=False)
    parser.add_argument('--on_cluster', action='store_true', default=False)
    parser.add_argument('--resume', action='store_true', default=False)
    parser.add_argument('--line_search', action='store_true', default=False)
    parser.add_argument('--wandb_logging', action='store_true', default=False)
    parser.add_argument('--wandb_exp_name')
    parser.add_argument('--step_size', type=float, default=1e-06)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--num_iter_accum_grad', type=int, default=1)
    parser.add_argument('--max_pass', type=int, default=10)
    parser.add_argument('--opt_method', default='adam')
    parser.add_argument('--model_name', default='resnet18')
    args = parser.parse_args()

    device = torch.device('cuda') if args.on_cluster else torch.device('cpu')
    if args.on_cluster:
        os.environ['TORCH_HOME'] = os.environ.get('SLURM_TMPDIR', '.')

    if args.wandb_logging:
        wandb.init(project='quick_draw_crf', name=args.wandb_exp_name)
        wandb.config.update({
            "step_size": args.step_size,
            "opt_method": args.opt_method,
            "num_data_used_calc_grad":
            args.batch_size * args.num_iter_accum_grad,
            'model_name': args.model_name
        })

    num_cats = 340
    num_tr_files = 20
    val_data_fh = h5py.File(
        os.path.join(args.dataset_loc, 'quick_draw_resnet_val_data.hdf5'), 'r')
    val_dataset = Image_dataset(data=val_data_fh['val_data'],
                                label=val_data_fh['val_label'],
                                data_size=len(val_data_fh['val_data']))
    validate_loader = data.DataLoader(val_dataset, batch_size=args.batch_size)
    validate_data_true_label = parse_validation_data_labels(
        val_data_fh['val_label'][:])  # Assumed no zero entries here!

    if args.model_name == 'resnet18':
        resnet = torchvision.models.resnet18(pretrained=False,
                                             num_classes=num_cats).to(device)
    elif args.model_name == 'resnet34':
        resnet = torchvision.models.resnet34(pretrained=False,
                                             num_classes=num_cats).to(device)
    elif args.model_name == 'resnet50':
        resnet = torchvision.models.resnet50(pretrained=False,
                                             num_classes=num_cats).to(device)
    else:
        raise ValueError('unsupported --model_name: %s' % args.model_name)
    optimizer = optim.Adam(resnet.parameters(),
                           lr=args.step_size,
                           weight_decay=args.weight_decay)

    if args.load_model:
        checkpoint = torch.load(args.model_loading_path)
        resnet.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        resnet.train()
    else:
        epoch = 0
    criterion = nn.CrossEntropyLoss()
    #optimizer = optim.SGD(resnet.parameters(), lr=args.step_size, momentum=0.9)
    while epoch < args.max_pass:
        optimizer.zero_grad()
        i = 1

        tr_data_fh = h5py.File(
            os.path.join(
                args.dataset_loc, 'quick_draw_resnet_data_' +
                str(np.random.randint(0, num_tr_files)) + '.hdf5'), 'r')
        tr_dataset = Image_dataset(data=tr_data_fh['tr_data'],
                                   label=tr_data_fh['tr_label'],
                                   data_size=len(tr_data_fh['tr_data']))
        train_loader = data.DataLoader(tr_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True)

        for image_batch, data_labels in train_loader:
            image_batch = image_batch.to(device)
            data_labels = data_labels.to(device)
            resnet.train()
            outputs = resnet(image_batch)
            loss = criterion(outputs, data_labels)
            loss.backward()
            if i % args.num_iter_accum_grad == 0:
                optimizer.step()
                optimizer.zero_grad()

                val_err = validate(resnet, validate_data_true_label,
                                   validate_loader, val_dataset.data_size,
                                   device)
                if args.wandb_logging:
                    wandb.log({'val_err': val_err, 'tr_loss': loss.item()})
                else:
                    print('epoch=%d. validation error = %.3f' %
                          (epoch, val_err))
            i += 1
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': resnet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, os.path.join(args.model_saving_path, args.model_name + '.pt'))
        epoch += 1
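Finally, the Image_dataset wrapper Example #5 relies on is not shown. Below is a minimal sketch of a map-style torch Dataset over the open h5py arrays, matching the (data, label, data_size) constructor and the (image, label) pairs the loaders unpack; the dtype choices are assumptions.

import torch
from torch.utils import data

class Image_dataset(data.Dataset):
    def __init__(self, data, label, data_size):
        self.data = data            # h5py dataset of images
        self.label = label          # h5py dataset of integer labels
        self.data_size = data_size  # number of examples

    def __len__(self):
        return self.data_size

    def __getitem__(self, index):
        image = torch.as_tensor(self.data[index], dtype=torch.float32)
        label = torch.as_tensor(self.label[index], dtype=torch.long)
        return image, label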