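# Imports assumed by this section; seeding(), build_dataset() and
# build_dataloader() are local helpers defined elsewhere in this repo.
# The hfta.* and model import paths follow the upstream HFTA examples
# and may need adjusting to your layout.
import os

import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda import amp

from hfta.optim import (consolidate_hyperparams_and_determine_B,
                        get_hfta_lr_scheduler_for, get_hfta_optim_for)
from hfta.workflow import EpochTimer
from pointnet.model import PointNetCls, feature_transform_regularizer

# torch_xla is required only for --device xla.
try:
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
except ImportError:
    xm = met = None

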
def main(args):
    seeding(args.seed)

    if args.hfta:
        # Fuse all hyperparameter combinations into a single model array;
        # B is the resulting number of fused training runs.
        B = consolidate_hyperparams_and_determine_B(
            args,
            ['lr', 'beta1', 'beta2', 'weight_decay', 'gamma', 'step_size'],
        )
    else:
        # HFTA disabled: B = 0 and a single run uses the first value of
        # each hyperparameter list.
        B = 0
        (args.lr, args.beta1, args.beta2, args.weight_decay, args.gamma,
         args.step_size) = (args.lr[0], args.beta1[0], args.beta2[0],
                            args.weight_decay[0], args.gamma[0],
                            args.step_size[0])

    if args.device == 'cuda':
        assert torch.cuda.is_available()
        torch.backends.cudnn.benchmark = True
        print('Enable cuDNN heuristics!')
    device = (xm.xla_device()
              if args.device == 'xla' else torch.device(args.device))

    dataset, test_dataset = build_dataset(args)
    dataloader, testdataloader = build_dataloader(args, dataset, test_dataset)

    print('len(dataset)={}'.format(len(dataset)),
          'len(test_dataset)={}'.format(len(test_dataset)))
    num_classes = len(dataset.classes)
    print('classes', num_classes)

    if args.outf is not None:
        os.makedirs(args.outf, exist_ok=True)

    classifier = PointNetCls(
        k=num_classes,
        feature_transform=args.feature_transform,
        B=B,
        track_running_stats=(args.device != 'xla'),
    )

    if args.model != '':
        classifier.load_state_dict(torch.load(args.model))

    # The HFTA wrappers return fused variants of Adam / StepLR that hold
    # one hyperparameter set per fused model (plain Adam / StepLR when
    # B == 0).
    optimizer = get_hfta_optim_for(optim.Adam, B=B)(
        classifier.parameters(),
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        weight_decay=args.weight_decay,
    )
    scheduler = get_hfta_lr_scheduler_for(optim.lr_scheduler.StepLR, B=B)(
        optimizer,
        step_size=args.step_size,
        gamma=args.gamma,
    )

    # Gradient scaling is active only for CUDA + AMP; elsewhere the
    # scaler is a no-op passthrough.
    scaler = amp.GradScaler(enabled=(args.device == 'cuda' and args.amp))

    classifier.to(device)

    num_batch = len(dataloader)

    def loss_fn(output, label, batch_size, trans_feat):
        if B > 0:
            # nll_loss averages over all B * batch_size fused samples;
            # multiplying by B restores the per-model mean loss scale so
            # each fused model sees the same gradient magnitude as an
            # unfused run.
            loss = B * F.nll_loss(output.view(B * batch_size, -1), label)
        else:
            loss = F.nll_loss(output, label)
        if args.feature_transform:
            loss += feature_transform_regularizer(trans_feat) * 0.001
        return loss

    classifier = classifier.train()
    epoch_timer = EpochTimer()

    # Training loop
    for epoch in range(args.epochs):
        num_samples_per_epoch = 0
        epoch_timer.epoch_start(epoch)
        for i, data in enumerate(dataloader, 0):
            # Cap the number of iterations per epoch.
            if i > args.iters_per_epoch:
                break
            # Warmup mode only exercises the data-loading pipeline; skip
            # all compute.
            if args.warmup_data_loading:
                continue

            points, target = data
            target = target[:, 0]
            points, target = points.to(device), target.to(device)
            N = points.size(0)
            if B > 0:
                # Replicate the batch along a leading array dimension so
                # all B fused models train on the same data.
                points = points.unsqueeze(0).expand(B, -1, -1, -1).contiguous()
                target = target.repeat(B)
            optimizer.zero_grad(set_to_none=True)
            if args.device == 'cuda':
                # Mixed-precision path: scale the loss to avoid fp16
                # gradient underflow, then unscale inside scaler.step().
                with amp.autocast(enabled=args.amp):
                    pred, trans, trans_feat = classifier(points)
                    loss = loss_fn(pred, target, N, trans_feat)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                pred, trans, trans_feat = classifier(points)
                loss = loss_fn(pred, target, N, trans_feat)
                loss.backward()
                if args.device == 'xla':
                    # xm.optimizer_step() marks the XLA step boundary and
                    # materializes the lazily-built graph.
                    xm.optimizer_step(optimizer, barrier=True)
                else:
                    optimizer.step()

            print('[{}: {}/{}] train loss: {:.6f}'.format(
                epoch, i, num_batch, loss.item()))
            num_samples_per_epoch += N * max(B, 1)
        scheduler.step()
        epoch_timer.epoch_stop(num_samples_per_epoch)
        print('Epoch {} took {} s!'.format(epoch,
                                           epoch_timer.epoch_latency(epoch)))

    if args.device == 'xla' and not args.eval:
        # Dump XLA metrics (compile counts, op timings) for debugging.
        print(met.metrics_report())
    if args.outf is not None:
        epoch_timer.to_csv(args.outf)

    if args.eval:
        # Run the validation loop.
        print('Running validation loop ...')
        classifier = classifier.eval()
        with torch.no_grad():
            # One correct-prediction counter per fused model (a single
            # counter when HFTA is disabled).
            total_correct = torch.zeros(max(B, 1), device=device)
            total_testset = 0
            for data in testdataloader:
                if args.warmup_data_loading:
                    continue
                points, target = data
                target = target[:, 0]
                points, target = points.to(device), target.to(device)
                N = points.size(0)
                if B > 0:
                    points = points.unsqueeze(0).expand(B, -1, -1,
                                                        -1).contiguous()
                    target = target.repeat(B)
                pred, _, _ = classifier(points)
                pred_choice = pred.argmax(-1)

                # Reduce over the batch dimension, producing one correct
                # count per fused model when B > 0.
                correct = pred_choice.eq(
                    target.view(B, N) if B > 0 else target).sum(-1)

                total_correct.add_(correct)
                total_testset += N

            final_accuracy = (total_correct / total_testset).cpu().tolist()
            if args.outf is not None:
                pd.DataFrame({
                    'acc': final_accuracy
                }).to_csv(os.path.join(args.outf, 'eval.csv'))

            # Return one accuracy per fused model (a single-element list
            # when HFTA is disabled).
            return final_accuracy
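

# A minimal, hypothetical entry point sketching how main() could be
# invoked. Flag names mirror the attributes main() reads from args;
# hyperparameters take multiple values so HFTA can fuse them. Defaults
# are illustrative, and any extra flags required by build_dataset() /
# build_dataloader() (e.g. dataset path, batch size) are omitted here.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--hfta', action='store_true')
    parser.add_argument('--lr', type=float, nargs='+', default=[0.001])
    parser.add_argument('--beta1', type=float, nargs='+', default=[0.9])
    parser.add_argument('--beta2', type=float, nargs='+', default=[0.999])
    parser.add_argument('--weight_decay', type=float, nargs='+', default=[0.0])
    parser.add_argument('--gamma', type=float, nargs='+', default=[0.5])
    parser.add_argument('--step_size', type=int, nargs='+', default=[20])
    parser.add_argument('--device', default='cuda',
                        choices=['cpu', 'cuda', 'xla'])
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--iters_per_epoch', type=int, default=int(1e9))
    parser.add_argument('--model', default='')
    parser.add_argument('--outf', default=None)
    parser.add_argument('--feature_transform', action='store_true')
    parser.add_argument('--amp', action='store_true')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--warmup_data_loading', action='store_true')
    main(parser.parse_args())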