def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
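    # NOTE: the normalization constants above are the widely used CIFAR-10
    # per-channel means/stds; this script applies them to CIFAR-100 as well.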
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100

    trainset = dataloader(root='./data',
                          train=True,
                          download=True,
                          transform=transform_train)
    trainloader = data.DataLoader(trainset,
                                  batch_size=args.train_batch,
                                  shuffle=True,
                                  num_workers=args.workers)

    testset = dataloader(root='./data',
                         train=False,
                         download=False,
                         transform=transform_test)
    testloader = data.DataLoader(testset,
                                 batch_size=args.test_batch,
                                 shuffle=False,
                                 num_workers=args.workers)

    # Model
    print("==> creating model '{}'".format(args.arch))
    if args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
            cardinality=args.cardinality,
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
    elif args.arch.startswith('densenet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            growthRate=args.growthRate,
            compressionRate=args.compressionRate,
            dropRate=args.drop,
        )
    elif args.arch.startswith('wrn'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
    elif args.arch.startswith('resnet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            block_name=args.block_name,
        )
    elif args.arch.startswith('preresnet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            block_name=args.block_name,
        )
    elif args.arch.startswith('horesnet'):
        model = models.__dict__[args.arch](num_classes=num_classes,
                                           depth=args.depth,
                                           eta=args.eta,
                                           block_name=args.block_name,
                                           feature_vec=args.feature_vec)
    elif args.arch.startswith('hopreresnet'):
        model = models.__dict__[args.arch](num_classes=num_classes,
                                           depth=args.depth,
                                           eta=args.eta,
                                           block_name=args.block_name,
                                           feature_vec=args.feature_vec)
    elif args.arch.startswith('nagpreresnet'):
        model = models.__dict__[args.arch](num_classes=num_classes,
                                           depth=args.depth,
                                           eta=args.eta,
                                           block_name=args.block_name,
                                           feature_vec=args.feature_vec)
    elif args.arch.startswith('mompreresnet'):
        model = models.__dict__[args.arch](num_classes=num_classes,
                                           depth=args.depth,
                                           eta=args.eta,
                                           block_name=args.block_name,
                                           feature_vec=args.feature_vec)
    elif args.arch.startswith('v2_preresnet'):
        if args.depth == 18:
            block_name = 'basicblock'
            num_blocks = [2, 2, 2, 2]
        elif args.depth == 34:
            block_name = 'basicblock'
            num_blocks = [3, 4, 6, 3]
        elif args.depth == 50:
            block_name = 'bottleneck'
            num_blocks = [3, 4, 6, 3]
        elif args.depth == 101:
            block_name = 'bottleneck'
            num_blocks = [3, 4, 23, 3]
        elif args.depth == 152:
            block_name = 'bottleneck'
            num_blocks = [3, 8, 36, 3]
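        else:
            # guard (not in the original): any other depth would leave
            # block_name/num_blocks undefined and crash below with a NameError
            raise ValueError('unsupported depth %d for v2_preresnet' % args.depth)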

        model = models.__dict__[args.arch](block_name=block_name,
                                           num_blocks=num_blocks,
                                           num_classes=num_classes)
    else:
        print("Unrecognized architecture '%s' - using the standard model" % args.arch)
        model = models.__dict__[args.arch](num_classes=num_classes)

    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))
    criterion = nn.CrossEntropyLoss()
    if args.optimizer.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'adamw':
        optimizer = AdamW(model.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay,
                          warmup=args.warmup)
    elif args.optimizer.lower() == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               betas=(args.beta1, args.beta2),
                               weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'srsgd':
        iter_count = 1
        optimizer = SGD_Adaptive(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay,
                                 iter_count=iter_count,
                                 restarting_iter=args.restart_schedule[0])
    elif args.optimizer.lower() == 'sradam':
        iter_count = 1
        optimizer = SRNAdam(model.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, args.beta2),
                            iter_count=iter_count,
                            weight_decay=args.weight_decay,
                            restarting_iter=args.restart_schedule[0])
    elif args.optimizer.lower() == 'sradamw':
        iter_count = 1
        optimizer = SRAdamW(model.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, args.beta2),
                            iter_count=iter_count,
                            weight_decay=args.weight_decay,
                            warmup=args.warmup,
                            restarting_iter=args.restart_schedule[0])
    elif args.optimizer.lower() == 'srradam':
        #NOTE: need to double-check this
        iter_count = 1
        optimizer = SRRAdam(model.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, args.beta2),
                            iter_count=iter_count,
                            weight_decay=args.weight_decay,
                            warmup=args.warmup,
                            restarting_iter=args.restart_schedule[0])
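    else:
        # guard (not in the original): fail fast on an unknown optimizer name
        # instead of raising NameError at the first use of `optimizer` below
        raise ValueError("unknown optimizer '%s'" % args.optimizer)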

    schedule_index = 1

    # Resume
    title = '%s-' % args.dataset + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        # args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if args.optimizer.lower() in ('srsgd', 'sradam', 'sradamw', 'srradam'):
            iter_count = optimizer.param_groups[0]['iter_count']
        # resume the schedule position saved with the checkpoint (the original
        # hard-coded `schedule_index = 3` here, overriding the saved value)
        schedule_index = checkpoint.get('schedule_index', schedule_index)
        state['lr'] = optimizer.param_groups[0]['lr']
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.',
            'Valid Acc.'
        ])

    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(testloader, model, criterion, start_epoch,
                                   use_cuda)
        print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
        return

    # Train and val

    for epoch in range(start_epoch, args.epochs):
        if args.optimizer.lower() == 'srsgd':
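            # SRSGD schedule: at each epoch in args.schedule the optimizer is
            # rebuilt with a decayed LR and the next restart interval; from
            # epoch 161 the restart interval decays linearly to 1 over the
            # remaining epochs (the 161/162 constants appear to assume a
            # ~200-epoch CIFAR run). Note that resuming from an epoch past 161
            # would hit `start_decay_restarting_iter` before it is set.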
            if epoch == 161:
                start_decay_restarting_iter = args.restart_schedule[
                    schedule_index] - 1
                current_lr = args.lr * (args.gamma**schedule_index)

            if epoch in args.schedule:
                current_lr = args.lr * (args.gamma**schedule_index)
                current_restarting_iter = args.restart_schedule[schedule_index]
                optimizer = SGD_Adaptive(
                    model.parameters(),
                    lr=current_lr,
                    weight_decay=args.weight_decay,
                    iter_count=iter_count,
                    restarting_iter=current_restarting_iter)
                schedule_index += 1

            if epoch >= 161:
                current_restarting_iter = start_decay_restarting_iter * (
                    args.epochs - epoch - 1) / (args.epochs - 162) + 1
                optimizer = SGD_Adaptive(
                    model.parameters(),
                    lr=current_lr,
                    weight_decay=args.weight_decay,
                    iter_count=iter_count,
                    restarting_iter=current_restarting_iter)

        else:
            adjust_learning_rate(optimizer, epoch)

        logger.file.write('\nEpoch: [%d | %d] LR: %f' %
                          (epoch + 1, args.epochs, state['lr']))

        if args.optimizer.lower() in ('srsgd', 'sradam', 'sradamw', 'srradam'):
            train_loss, train_acc, iter_count = train(trainloader, model,
                                                      criterion, optimizer,
                                                      epoch, use_cuda, logger)
        else:
            train_loss, train_acc = train(trainloader, model, criterion,
                                          optimizer, epoch, use_cuda, logger)

        test_loss, test_acc = test(testloader, model, criterion, epoch,
                                   use_cuda, logger)

        # append logger file
        logger.append(
            [state['lr'], train_loss, test_loss, train_acc, test_acc])

        writer.add_scalars('train_loss', {args.model_name: train_loss}, epoch)
        writer.add_scalars('test_loss', {args.model_name: test_loss}, epoch)
        writer.add_scalars('train_acc', {args.model_name: train_acc}, epoch)
        writer.add_scalars('test_acc', {args.model_name: test_acc}, epoch)

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'schedule_index': schedule_index,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            epoch,
            checkpoint=args.checkpoint)

    logger.file.write('Best acc:%f' % best_acc)

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, 'log.eps'))

    print('Best acc:')
    print(best_acc)

    with open("./all_results.txt", "a") as f:
        fcntl.flock(f, fcntl.LOCK_EX)
        f.write("%s\n" % args.checkpoint)
        f.write("best_acc %f\n\n" % best_acc)
        fcntl.flock(f, fcntl.LOCK_UN)
Example #2
def train(
        train_data,
        exp_dir=datetime.now().strftime("corrector_model/%Y-%m-%d_%H%M"),
        learning_rate=0.00005,
        rsize=10,
        epochs=1,
        checkpoint_path='',
        seed=6548,
        batch_size=4,
        edge_loss=False,
        model_type='cnet',
        model_cap='normal',
        optimizer_type='radam',
        reset_optimizer=True,  # if true, does not load optimizer checkpoints
        safe_descent=True,
        activation_type='mish',
        activation_args={},
        io=None,
        dynamic_lr=True,
        dropout=0,
        rotations=False,
        use_batch_norm=True,
        batch_norm_momentum=None,
        batch_norm_affine=True,
        use_gc=True,
        no_lr_schedule=False,
        diff_features_only=False):

    start_time = time.time()

    io.cprint("-------------------------------------------------------" +
              "\nexport dir = " + '/checkpoints/' + exp_dir +
              "\nbase_learning_rate = " + str(learning_rate) +
              "\nuse_batch_norm = " + str(use_batch_norm) +
              "\nbatch_norm_momentum = " + str(batch_norm_momentum) +
              "\nbatch_norm_affine = " + str(batch_norm_affine) +
              "\nno_lr_schedule = " + str(no_lr_schedule) + "\nuse_gc = " +
              str(use_gc) + "\nrsize = " + str(rsize) + "\npython_version: " +
              sys.version + "\ntorch_version: " + torch.__version__ +
              "\nnumpy_version: " + np.version.version + "\nmodel_type: " +
              model_type + "\nmodel_cap: " + model_cap + "\noptimizer: " +
              optimizer_type + "\nactivation_type: " + activation_type +
              "\nsafe_descent: " + str(safe_descent) + "\ndynamic_lr: " +
              str(dynamic_lr) + "\nrotations: " + str(rotations) +
              "\nepochs = " + str(epochs) +
              (("\ncheckpoint = " + checkpoint_path) if
               (checkpoint_path != None and checkpoint_path != '') else '') +
              "\nseed = " + str(seed) + "\nbatch_size = " + str(batch_size) +
              "\n#train_data = " +
              str(sum([bin.size(0) for bin in train_data["train_bins"]])) +
              "\n#test_data = " + str(len(train_data["test_samples"])) +
              "\n#validation_data = " + str(len(train_data["val_samples"])) +
              "\nedge_loss = " + str(edge_loss) +
              "\n-------------------------------------------------------" +
              "\nstart_time: " + datetime.now().strftime("%Y-%m-%d_%H%M%S") +
              "\n-------------------------------------------------------")

    # initialize torch & cuda ---------------------------------------------------------------------

    torch.manual_seed(seed)
    np.random.seed(seed)

    device = utils.getDevice(io)

    # extract train- & test data (and move to device) --------------------------------------------

    # train_bins = [bin.float().to(device) for bin in train_data["train_bins"]]
    # test_samples = [sample.float().to(device) for sample in train_data["test_samples"]]
    # val_samples = [sample.float().to(device) for sample in train_data["val_samples"]]

    train_bins = [bin.float() for bin in train_data["train_bins"]]
    test_samples = [sample.float() for sample in train_data["test_samples"]]
    val_samples = [sample.float() for sample in train_data["val_samples"]]

    # Initialize Model ------------------------------------------------------------------------------

    model_args = {
        'model_type': model_type,
        'model_cap': model_cap,
        'input_channels': test_samples[0].size(1),
        'output_channels': test_samples[0].size(1),
        'rsize': rsize,
        'emb_dims': 1024,
        'activation_type': activation_type,
        'activation_args': activation_args,
        'dropout': dropout,
        'batch_norm': use_batch_norm,
        'batch_norm_affine': batch_norm_affine,
        'batch_norm_momentum': batch_norm_momentum,
        'diff_features_only': diff_features_only
    }

    model = getModel(model_args).to(device)

    # init optimizer & scheduler -------------------------------------------------------------------

    lookahead_sync_period = 6

    optimizer = None
    if optimizer_type == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=learning_rate,
                          betas=(0.9, 0.999),
                          eps=1e-8,
                          use_gc=use_gc)
    elif optimizer_type == 'lookahead':
        optimizer = Ranger(model.parameters(),
                           lr=learning_rate,
                           alpha=0.9,
                           k=lookahead_sync_period)
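    # guard (not in the original): any other optimizer_type leaves `optimizer`
    # as None and the run would crash at optimizer.zero_grad() below
    assert optimizer is not None, "unknown optimizer_type '%s'" % optimizer_type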

    # make sure that either a LR schedule is given or dynamic LR is enabled
    assert dynamic_lr or not no_lr_schedule

    scheduler = None if no_lr_schedule else MultiplicativeLR(
        optimizer, lr_lambda=MultiplicativeAnnealing(epochs))

    # set train settings & load previous model state ------------------------------------------------------------

    checkpoint = getEmptyCheckpoint()
    last_epoch = 0

    if checkpoint_path is not None and checkpoint_path != '':
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'][-1])
        if not reset_optimizer:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'][-1])
        last_epoch = len(checkpoint['model_state_dict'])
        print('> loaded checkpoint! (%d epochs)' % (last_epoch))

    checkpoint['train_settings'].append({
        'learning_rate': learning_rate,
        'scheduler': scheduler,
        'epochs': epochs,
        'seed': seed,
        'batch_size': batch_size,
        'edge_loss': edge_loss,
        'optimizer': optimizer_type,
        'safe_descent': str(safe_descent),
        'dynamic_lr': str(dynamic_lr),
        'rotations': str(rotations),
        'train_data_count': sum([bin.size(0) for bin in train_data["train_bins"]]),
        'test_data_count': len(train_data["test_samples"]),
        'validation_data_count': len(train_data["val_samples"]),
        'model_args': model_args
    })

    # set up report interval (for logging) and batch size -------------------------------------------------------------------

    report_interval = 100
    loss_function = torch.nn.MSELoss(reduction='mean')

    # begin training ###########################################################################################################################

    io.cprint("\nBeginning Training..\n")

    for epoch in range(last_epoch + 1, last_epoch + epochs + 1):

        io.cprint(
            "Epoch: %d ------------------------------------------------------------------------------------------"
            % (epoch))
        io.cprint("Current LR: %.10f" % (optimizer.param_groups[0]['lr']))

        model.train()
        optimizer.zero_grad()

        checkpoint['train_batch_loss'].append([])
        checkpoint['train_batch_N'].append([])
        checkpoint['train_batch_lr_adjust'].append([])
        checkpoint['train_batch_loss_reduction'].append([])
        checkpoint['lr'].append(optimizer.param_groups[0]['lr'])

        # draw random batches from random bins
        binbatches = utils.drawBinBatches([bin.size(0) for bin in train_bins],
                                          batchsize=batch_size)

        checkpoint['train_batch_N'][-1] = [
            train_bins[bin_id][batch_ids].size(1)
            for (bin_id, batch_ids) in binbatches
        ]

        failed_loss_optims = 0
        cum_lr_adjust_fac = 0
        cum_loss_reduction = 0

        # pre-compute random rotations if needed
        batch_rotations = [None] * len(binbatches)
        if rotations:
            start_rotations = time.time()
            batch_rotations = torch.zeros(
                (len(binbatches), batch_size, test_samples[0].size(1),
                 test_samples[0].size(1)),
                device=device)
            for i in range(len(binbatches)):
                for j in range(batch_size):
                    batch_rotations[i, j] = utils.getRandomRotation(
                        test_samples[0].size(1), device=device)
            print("created batch rotations (%ds)" %
                  (time.time() - start_rotations))

        b = 0  # batch counter

        train_start = time.time()

        for (bin_id, batch_ids) in binbatches:

            b += 1

            # print ("handling batch %d" % (b))

            # prediction & loss ----------------------------------------

            batch_sample = train_bins[bin_id][batch_ids].to(
                model.base.device)  # size: (B x N x d x 2)

            batch_loss = getBatchLoss(model,
                                      batch_sample,
                                      loss_function,
                                      edge_loss=edge_loss,
                                      rotations=batch_rotations[b - 1])
            batch_loss.backward()

            checkpoint['train_batch_loss'][-1].append(batch_loss.item())

            new_loss = 0.0
            lr_adjust = 1.0
            loss_reduction = 0.0

            # if safe descent is enabled, try to optimize the descent step so that a reduction in loss is guaranteed
            if safe_descent:

                # create backups to restore states before the optimizer step
                model_state_backup = copy.deepcopy(model.state_dict())
                opt_state_backup = copy.deepcopy(optimizer.state_dict())

                # make an optimizer step
                optimizer.step()

                # in each iteration, check if the optimizer step gave an improvement;
                # if not, restore the original states, reduce the learning rate and try again
                # (no gradient needed for the plain loss calculation)
                with torch.no_grad():
                    for i in range(10):

                        new_loss = getBatchLoss(
                            model,
                            batch_sample,
                            loss_function,
                            edge_loss=edge_loss,
                            rotations=batch_rotations[b - 1]).item()

                        # if the model performs better now we continue, if not we try a smaller learning step
                        if (new_loss < batch_loss.item()):
                            # print("lucky! (%f -> %f) reduction: %.4f%%" % (batch_loss.item(), new_loss, 100 * (batch_loss.item()-new_loss) / batch_loss.item()))
                            break
                        else:
                            # print("try again.. (%f -> %f)" % (batch_loss.item(), new_loss))
                            model.load_state_dict(model_state_backup)
                            optimizer.load_state_dict(opt_state_backup)
                            lr_adjust *= 0.7
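                            # NOTE: step(lr_adjust=...) relies on the custom
                            # RAdam/Ranger implementations used here; stock
                            # torch.optim optimizers do not accept this kwarg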
                            optimizer.step(lr_adjust=lr_adjust)

                loss_reduction = 100 * (batch_loss.item() -
                                        new_loss) / batch_loss.item()

                if new_loss >= batch_loss.item():
                    failed_loss_optims += 1
                else:
                    cum_lr_adjust_fac += lr_adjust
                    cum_loss_reduction += loss_reduction

            else:

                cum_lr_adjust_fac += lr_adjust
                optimizer.step()

            checkpoint['train_batch_lr_adjust'][-1].append(lr_adjust)
            checkpoint['train_batch_loss_reduction'][-1].append(loss_reduction)

            # reset gradients
            optimizer.zero_grad()

            # statistic calculation and output -------------------------

            if b % report_interval == 0:

                last_100_loss = sum(checkpoint['train_batch_loss'][-1]
                                    [b - report_interval:b]) / report_interval
                improvement_indicator = '+' if epoch > 1 and last_100_loss < checkpoint[
                    'train_loss'][-1] else ''

                io.cprint(
                    '  Batch %4d to %4d | loss: %.10f%1s | av. dist. per neighbor: %.10f | E%3d | T:%5ds | Failed Optims: %3d (%05.2f%%) | Av. Adjust LR: %.6f | Av. Loss Reduction: %07.4f%%'
                    % (b - (report_interval - 1), b, last_100_loss,
                       improvement_indicator, np.sqrt(last_100_loss), epoch,
                       time.time() - train_start, failed_loss_optims, 100 *
                       (failed_loss_optims / report_interval),
                       (cum_lr_adjust_fac /
                        (report_interval - failed_loss_optims)
                        if failed_loss_optims < report_interval else -1),
                       (cum_loss_reduction /
                        (report_interval - failed_loss_optims)
                        if failed_loss_optims < report_interval else -1)))

                failed_loss_optims = 0
                cum_lr_adjust_fac = 0
                cum_loss_reduction = 0

        checkpoint['train_loss'].append(
            sum(checkpoint['train_batch_loss'][-1]) / b)
        checkpoint['train_time'].append(time.time() - train_start)

        io.cprint(
            '----\n  TRN | time: %5ds | loss: %.10f| av. dist. per neighbor: %.10f'
            % (checkpoint['train_time'][-1], checkpoint['train_loss'][-1],
               np.sqrt(checkpoint['train_loss'][-1])))

        torch.cuda.empty_cache()

        ####################
        # Test & Validation
        ####################

        with torch.no_grad():

            if use_batch_norm:

                model.eval_bn()

                eval_bn_start = time.time()

                # run through all train samples again to accumulate layer-wise input distribution statistics (mean and variance) with fixed weights
                # these statistics are later used for the BatchNorm layers during inference
                for (bin_id, batch_ids) in binbatches:
                    input = train_bins[bin_id][batch_ids][:, :, :, 0].squeeze(-1)  # size: (B x N x d)
                    # forward pass only; the output is discarded, the BN layers
                    # just accumulate running statistics
                    model(input.transpose(1, 2).to(model.base.device))

                io.cprint('Accumulated BN Layer statistics (%ds)' %
                          (time.time() - eval_bn_start))

            model.eval()

            test_start = time.time()

            test_loss = getTestLoss(model,
                                    test_samples,
                                    loss_function,
                                    edge_loss=edge_loss)

            checkpoint['test_loss'].append(test_loss)
            checkpoint['test_time'].append(time.time() - test_start)

            io.cprint(
                '  TST | time: %5ds | loss: %.10f| av. dist. per neighbor: %.10f'
                % (checkpoint['test_time'][-1], checkpoint['test_loss'][-1],
                   np.sqrt(checkpoint['test_loss'][-1])))

            val_start = time.time()

            val_loss = getTestLoss(model,
                                   val_samples,
                                   loss_function,
                                   edge_loss=edge_loss)

            checkpoint['val_loss'].append(val_loss)
            checkpoint['val_time'].append(time.time() - val_start)

            io.cprint(
                '  VAL | time: %5ds | loss: %.10f| av. dist. per neighbor: %.10f'
                % (checkpoint['val_time'][-1], checkpoint['val_loss'][-1],
                   np.sqrt(checkpoint['val_loss'][-1])))

        ####################
        # Scheduler Step
        ####################

        if not no_lr_schedule:
            scheduler.step()

        if epoch > 1 and dynamic_lr and sum(
                checkpoint['train_batch_lr_adjust'][-1]) > 0:
            io.cprint("----\n  dynamic lr adjust: %.10f" %
                      (0.5 *
                       (1 + sum(checkpoint['train_batch_lr_adjust'][-1]) /
                        len(checkpoint['train_batch_lr_adjust'][-1]))))
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.5 * (
                    1 + sum(checkpoint['train_batch_lr_adjust'][-1]) /
                    len(checkpoint['train_batch_lr_adjust'][-1]))

        # Save model and optimizer state ..
        checkpoint['model_state_dict'].append(copy.deepcopy(
            model.state_dict()))
        checkpoint['optimizer_state_dict'].append(
            copy.deepcopy(optimizer.state_dict()))

        torch.save(checkpoint, exp_dir + '/corrector_checkpoints.t7')
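        # NOTE: the banner at the top reports '/checkpoints/' + exp_dir, but
        # the file is written relative to exp_dir itself, which must already
        # exist (torch.save does not create directories)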

    io.cprint("\n-------------------------------------------------------" +
              ("\ntotal_time: %.2fh" % ((time.time() - start_time) / 3600)) +
              ("\ntrain_time: %.2fh" %
               (sum(checkpoint['train_time']) / 3600)) +
              ("\ntest_time: %.2fh" % (sum(checkpoint['test_time']) / 3600)) +
              ("\nval_time: %.2fh" % (sum(checkpoint['val_time']) / 3600)) +
              "\n-------------------------------------------------------" +
              "\nend_time: " + datetime.now().strftime("%Y-%m-%d_%H%M%S") +
              "\n-------------------------------------------------------")
Example #3
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100

    trainset = dataloader(root='./data',
                          train=True,
                          download=True,
                          transform=transform_train)
    trainloader = data.DataLoader(trainset,
                                  batch_size=args.train_batch,
                                  shuffle=True,
                                  num_workers=args.workers)

    testset = dataloader(root='./data',
                         train=False,
                         download=False,
                         transform=transform_test)
    testloader = data.DataLoader(testset,
                                 batch_size=args.test_batch,
                                 shuffle=False,
                                 num_workers=args.workers)

    # Model
    print("==> creating model '{}'".format(args.arch))
    if args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
            cardinality=args.cardinality,
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
    elif args.arch.startswith('densenet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            growthRate=args.growthRate,
            compressionRate=args.compressionRate,
            dropRate=args.drop,
        )
    elif args.arch.startswith('wrn'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
    elif args.arch.endswith('resnet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            block_name=args.block_name,
        )
    else:
        model = models.__dict__[args.arch](num_classes=num_classes)

    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))
    criterion = nn.CrossEntropyLoss()
    if args.optimizer.lower() == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
    # elif args.optimizer.lower() == 'adam':
    #     optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay)
    elif args.optimizer.lower() == 'adamw':
        optimizer = AdamW(model.parameters(),
                          lr=args.lr,
                          betas=(args.beta1, args.beta2),
                          weight_decay=args.weight_decay,
                          warmup=args.warmup)
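    else:
        # guard (not in the original): unknown optimizer names would otherwise
        # surface as a NameError at the first use of `optimizer`
        raise ValueError("unknown optimizer '%s'" % args.optimizer)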
    # Resume
    title = 'cifar-10-' + args.arch
    # if args.resume:
    #     # Load checkpoint.
    #     print('==> Resuming from checkpoint..')
    #     assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!'
    #     args.checkpoint = os.path.dirname(args.resume)
    #     checkpoint = torch.load(args.resume)
    #     best_acc = checkpoint['best_acc']
    #     start_epoch = checkpoint['epoch']
    #     model.load_state_dict(checkpoint['state_dict'])
    #     optimizer.load_state_dict(checkpoint['optimizer'])
    #     logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
    # else:
    logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
    logger.set_names([
        'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'
    ])

    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(testloader, model, criterion, start_epoch,
                                   use_cuda)
        print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, state['lr']))

        train_loss, train_acc = train(trainloader, model, criterion, optimizer,
                                      epoch, use_cuda)
        test_loss, test_acc = test(testloader, model, criterion, epoch,
                                   use_cuda)

        # append logger file
        logger.append(
            [state['lr'], train_loss, test_loss, train_acc, test_acc])
        # writer.add_scalars('loss_tracking/train_loss', {args.model_name: train_loss}, epoch)
        # writer.add_scalars('loss_tracking/test_loss', {args.model_name: test_loss}, epoch)
        # writer.add_scalars('loss_tracking/train_acc', {args.model_name: train_acc}, epoch)
        # writer.add_scalars('loss_tracking/test_acc', {args.model_name: test_acc}, epoch)

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            checkpoint=args.checkpoint)

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, 'log.eps'))

    print('Best acc:')
    print(best_acc)
Example #4
def main_worker(gpu, ngpus_per_node, args, writer):
    global best_f1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
        model.fc = nn.Linear(512 * 1, 4) # replace final classifier
    else:
        print("=> creating model '{}'".format(args.arch))
        if args.arch == 'residual_attention_network':
            from model.residual_attention_network import ResidualAttentionModel_92
            model = ResidualAttentionModel_92(num_classes=4)
        else:
            model = models.__dict__[args.arch](num_classes=4)

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss(reduction='none').cuda(args.gpu)
    
    # Better Adam optimizer: https://github.com/LiyuanLucasLiu/RAdam
    optimizer = RAdam(model.parameters(), lr=args.lr)

    # optionally resume from a checkpoint
    checkpoint = None
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_f1 = checkpoint['best_f1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_f1 = best_f1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    train_data = args.train_data
    val_data = args.val_data
    
    if not args.evaluate:
        if not (os.path.isfile(train_data) and os.path.splitext(train_data)[-1] == '.csv'):
            RoofImages.to_csv_datasource(train_data,csv_filename='tmp_train_set.csv', calc_perf=True)
            traindir = 'tmp_train_set.csv'
        else:
            traindir = args.train_data
        
        if not (os.path.isfile(val_data) and os.path.splitext(val_data)[-1] == '.csv'):
            RoofImages.to_csv_datasource(val_data,csv_filename='tmp_val_set.csv', calc_perf=True)
            valdir = 'tmp_val_set.csv'
        else:
            valdir = args.val_data
    else:
        if not args.resume:
            print('Evaluation is chosen without resuming from a checkpoint. Please choose a checkpoint to load with the --resume parameter.')
            exit(1)
        
        if not (os.path.isfile(val_data) and os.path.splitext(val_data)[-1] == '.csv'):        
            RoofImages.to_csv_datasource(val_data,csv_filename='tmp_val_set.csv', calc_perf=True)
            valdir = 'tmp_val_set.csv'
        else:
            valdir = args.val_data
        
    
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if not args.evaluate:
        train_dataset = RoofImages(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.ColorJitter(0.4,0.4,0.4,0.4),
                transforms.RandomGrayscale(),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        if args.distributed:
            if args.weighted_sampling:
                print ('Warning: Weighted sampling not implemented for distributed training. So no weighted sampling will be performed')
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        else:
            if args.weighted_sampling:
                train_weights = np.array(train_dataset.train_weights)
                train_sampler = torch.utils.data.WeightedRandomSampler(train_weights, len(train_weights), replacement=True)
            else:
                train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    try:
        classes = checkpoint['classes']
    except (TypeError, KeyError):
        # checkpoint is None unless we resumed, and older checkpoints may not
        # contain the class list
        classes = None

    val_loader = torch.utils.data.DataLoader(
        RoofImages(valdir, transforms.Compose([
            transforms.Resize(256 if args.val_resize is None else args.val_resize),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),test_mode=args.evaluate, classes=classes),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        f1, results = validate(val_loader, model, criterion, args, 0, writer, test_mode=True)
        results.to_csv(args.result_file)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        idx_vec, loss_vec = train(train_loader, model, criterion, optimizer, epoch, args, writer)

        if args.loss_weighting:
            print ('Weighting losses')
            train_set_order = np.argsort(idx_vec)
            loss_vec = (loss_vec - loss_vec.min())/(loss_vec.max() - loss_vec.min())
            loss_vec_in_order = loss_vec[train_set_order]
            train_sampler.weights = torch.as_tensor(loss_vec_in_order, dtype=torch.double)
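            # NOTE: this assumes the WeightedRandomSampler branch above;
            # with --distributed (DistributedSampler) or train_sampler=None
            # this assignment would fail or have no effect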

        # evaluate on validation set
        f1, results = validate(val_loader, model, criterion, args, epoch, writer)
        results.to_csv(args.result_file)

        # remember best acc@1 and save checkpoint
        is_best = f1 > best_f1
        best_f1 = max(f1, best_f1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_f1': best_f1,
                'optimizer' : optimizer.state_dict(),
                'classes' : train_dataset.classes # Save the exact classnames this checkpoint was created with
            }, is_best=is_best, file_folder = args.log_dir)
Example #5
def train(train_data,
          exp_dir=datetime.now().strftime("detector_model/%Y-%m-%d_%H%M"),
          learning_rate=0.00005,
          rsize=10,
          epochs=1,
          checkpoint_path='',
          seed=6548,
          batch_size=4,
          model_type='cnet',
          model_cap='normal',
          optimizer='radam',
          safe_descent=True,
          activation_type='mish',
          activation_args={},
          io=None,
          dynamic_lr=True,
          dropout=0,
          rotations=False,
          use_batch_norm=True,
          batch_norm_momentum=None,
          batch_norm_affine=True,
          use_gc=True,
          no_lr_schedule=False,
          diff_features_only=False,
          scale_min=1,
          scale_max=1,
          noise=0):

    start_time = time.time()

    scale_min = scale_min if scale_min < 1 else 1
    scale_max = scale_max if scale_max > 1 else 1
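    # (clamped so that scale_min <= 1 <= scale_max)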

    io.cprint("-------------------------------------------------------" +
              "\nexport dir = " + '/checkpoints/' + exp_dir +
              "\nbase_learning_rate = " + str(learning_rate) +
              "\nuse_batch_norm = " + str(use_batch_norm) +
              "\nbatch_norm_momentum = " + str(batch_norm_momentum) +
              "\nbatch_norm_affine = " + str(batch_norm_affine) +
              "\nno_lr_schedule = " + str(no_lr_schedule) + "\nuse_gc = " +
              str(use_gc) + "\nrsize = " + str(rsize) + "\npython_version: " +
              sys.version + "\ntorch_version: " + torch.__version__ +
              "\nnumpy_version: " + np.version.version + "\nmodel_type: " +
              model_type + "\nmodel_cap: " + model_cap + "\noptimizer: " +
              optimizer + "\nactivation_type: " + activation_type +
              "\nsafe_descent: " + str(safe_descent) + "\ndynamic_lr: " +
              str(dynamic_lr) + "\nrotations: " + str(rotations) +
              "\nscaling: " + str(scale_min) + " to " + str(scale_max) +
              "\nnoise: " + str(noise) + "\nepochs = " + str(epochs) +
              (("\ncheckpoint = " +
                checkpoint_path) if checkpoint_path != '' else '') +
              "\nseed = " + str(seed) + "\nbatch_size = " + str(batch_size) +
              "\n#train_data = " +
              str(sum([bin.size(0) for bin in train_data["train_bins"]])) +
              "\n#test_data = " + str(len(train_data["test_samples"])) +
              "\n#validation_data = " + str(len(train_data["val_samples"])) +
              "\n-------------------------------------------------------" +
              "\nstart_time: " + datetime.now().strftime("%Y-%m-%d_%H%M%S") +
              "\n-------------------------------------------------------")

    # initialize torch & cuda ---------------------------------------------------------------------

    torch.manual_seed(seed)
    np.random.seed(seed)

    device = utils.getDevice(io)

    # extract train- & test data (and move to device) --------------------------------------------

    pts = train_data["pts"].to(device)
    val_pts = train_data["val_pts"].to(device)

    train_bins = train_data["train_bins"]
    test_samples = train_data["test_samples"]
    val_samples = train_data["val_samples"]

    # the maximum noise offset for each point is equal to the distance to its nearest neighbor
    max_noise = torch.square(pts[train_data["knn"][:, 0]] -
                             pts).sum(dim=1).sqrt()

    # Initialize Model ------------------------------------------------------------------------------

    model_args = {
        'model_type': model_type,
        'model_cap': model_cap,
        'input_channels': pts.size(1),
        'output_channels': 2,
        'rsize': rsize,
        'emb_dims': 1024,
        'activation_type': activation_type,
        'activation_args': activation_args,
        'dropout': dropout,
        'batch_norm': use_batch_norm,
        'batch_norm_affine': batch_norm_affine,
        'batch_norm_momentum': batch_norm_momentum,
        'diff_features_only': diff_features_only
    }

    model = getModel(model_args).to(device)

    # init optimizer & scheduler -------------------------------------------------------------------

    lookahead_sync_period = 6

    opt = None
    if optimizer == 'radam':
        opt = RAdam(model.parameters(),
                    lr=learning_rate,
                    betas=(0.9, 0.999),
                    eps=1e-8,
                    use_gc=use_gc)
    elif optimizer == 'lookahead':
        opt = Ranger(model.parameters(),
                     lr=learning_rate,
                     alpha=0.9,
                     k=lookahead_sync_period)
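    # guard (not in the original): fail fast if the optimizer name was not
    # recognized, instead of crashing later while `opt` is still None
    assert opt is not None, "unknown optimizer '%s'" % optimizer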

    # make sure that either a LR schedule is given or dynamic LR is enabled
    assert dynamic_lr or not no_lr_schedule

    scheduler = None if no_lr_schedule else MultiplicativeLR(
        opt, lr_lambda=MultiplicativeAnnealing(epochs))

    # set train settings & load previous model state ------------------------------------------------------------

    checkpoint = getEmptyCheckpoint()
    last_epoch = 0

    if checkpoint_path != '':
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'][-1])
        # bug fix: `optimizer` is the optimizer-name string in this function;
        # the optimizer object is `opt`
        opt.load_state_dict(checkpoint['optimizer_state_dict'][-1])
        last_epoch = len(checkpoint['model_state_dict'])
        print('> loaded checkpoint! (%d epochs)' % (last_epoch))

    checkpoint['train_settings'].append({
        'learning_rate': learning_rate,
        'scheduler': scheduler,
        'epochs': epochs,
        'seed': seed,
        'batch_size': batch_size,
        'optimizer': optimizer,
        'safe_descent': str(safe_descent),
        'dynamic_lr': str(dynamic_lr),
        'rotations': str(rotations),
        'scale_min': scale_min,
        'scale_max': scale_max,
        'noise': noise,
        'train_data_count': sum([bin.size(0) for bin in train_data["train_bins"]]),
        'test_data_count': len(train_data["test_samples"]),
        'validation_data_count': len(train_data["val_samples"]),
        'model_args': model_args
    })

    # calculate class weights ---------------------------------------------------------------------

    av_c1_freq = sum([
        torch.sum(bin[:, :, 1]).item() for bin in train_data["train_bins"]
    ]) / sum([bin[:, :, 1].numel() for bin in train_data["train_bins"]])
    class_weights = torch.tensor([av_c1_freq,
                                  1 - av_c1_freq]).float().to(device)

    io.cprint("\nC0 Weight: %.4f" % (class_weights[0].item()))
    io.cprint("C1 Weight: %.4f" % (class_weights[1].item()))

    # Adjust Weights in favor of C1 (edge:true class)
    # class_weights[0] = class_weights[0] / 2
    # class_weights[1] = 1 - class_weights[0]
    # io.cprint("\nAdjusted C0 Weight: %.4f" % (class_weights[0].item()))
    # io.cprint("Adjusted C1 Weight: %.4f" % (class_weights[1].item()))

    # set up report interval (for logging) and batch size -------------------------------------------------------------------

    report_interval = 100

    # begin training ###########################################################################################################################

    io.cprint("\nBeginning Training..\n")

    for epoch in range(last_epoch + 1, last_epoch + epochs + 1):

        io.cprint(
            "Epoch: %d ------------------------------------------------------------------------------------------"
            % (epoch))
        io.cprint("Current LR: %.10f" % (opt.param_groups[0]['lr']))

        model.train()
        opt.zero_grad()

        checkpoint['train_batch_loss'].append([])
        checkpoint['train_batch_N'].append([])
        checkpoint['train_batch_acc'].append([])
        checkpoint['train_batch_C0_acc'].append([])
        checkpoint['train_batch_C1_acc'].append([])
        checkpoint['train_batch_lr_adjust'].append([])
        checkpoint['train_batch_loss_reduction'].append([])
        checkpoint['lr'].append(opt.param_groups[0]['lr'])

        # draw random batches from random bins
        binbatches = utils.drawBinBatches([bin.size(0) for bin in train_bins],
                                          batchsize=batch_size)

        checkpoint['train_batch_N'][-1] = [
            train_bins[bin_id][batch_ids].size(1)
            for (bin_id, batch_ids) in binbatches
        ]

        failed_loss_optims = 0
        cum_lr_adjust_fac = 0
        cum_loss_reduction = 0

        # pre-compute random rotations if needed
        batch_rotations = [None] * len(binbatches)
        if rotations:
            start_rotations = time.time()
            batch_rotations = torch.zeros(
                (len(binbatches), batch_size, pts.size(1), pts.size(1)),
                device=device)
            for i in range(len(binbatches)):
                for j in range(batch_size):
                    batch_rotations[i, j] = utils.getRandomRotation(
                        pts.size(1), device=device)
            print("created batch rotations (%ds)" %
                  (time.time() - start_rotations))

        b = 0  # batch counter

        train_start = time.time()

        for (bin_id, batch_ids) in binbatches:

            b += 1

            batch_pts_ids = train_bins[bin_id][batch_ids][:, :, 0]  # size: (B x N)
            batch_input = pts[batch_pts_ids]  # size: (B x N x d)
            batch_target = train_bins[bin_id][batch_ids][:, :, 1].to(device)  # size: (B x N)

            if batch_rotations[b - 1] is not None:
                batch_input = batch_input.matmul(batch_rotations[b - 1])

            if noise > 0:
                noise_v = torch.randn(
                    batch_input.size(),
                    device=batch_input.device)  # size: (B x N x d)
                noise_v.div_(
                    torch.square(noise_v).sum(
                        dim=2).sqrt()[:, :, None])  # norm to unit vectors
                # in-place add: plain addcmul() returns a new tensor and
                # would silently discard the noise perturbation
                batch_input.addcmul_(noise_v,
                                     max_noise[batch_pts_ids][:, :, None],
                                     value=noise)

            if scale_min < 1 or scale_max > 1:
                # batch_scales = scale_min + torch.rand(batch_input.size(0), device=batch_input.device) * (scale_max - scale_min)
                batch_scales = torch.rand(batch_input.size(0),
                                          device=batch_input.device)
                batch_scales.mul_(scale_max - scale_min)
                batch_scales.add_(scale_min)
                batch_input.mul_(batch_scales[:, None, None])  # in-place; mul() would discard the result

            batch_input = batch_input.transpose(1, 2)  # size: (B x d x N)

            # prediction & loss ----------------------------------------

            batch_prediction = model(batch_input).transpose(
                1, 2)  # size: (B x N x 2)
            batch_loss = cross_entropy(batch_prediction.reshape(-1, 2),
                                       batch_target.view(-1),
                                       class_weights,
                                       reduction='mean')
            batch_loss.backward()

            checkpoint['train_batch_loss'][-1].append(batch_loss.item())

            new_loss = 0.0
            lr_adjust = 1.0
            loss_reduction = 0.0

            # if safe descent is enabled, try to optimize the descent step so that a reduction in loss is guaranteed
            if safe_descent:

                # create backups to restore states before the optimizer step
                model_state_backup = copy.deepcopy(model.state_dict())
                opt_state_backup = copy.deepcopy(opt.state_dict())

                # make an optimizer step
                opt.step()

                # in each iteration, check whether the optimizer step gave an improvement;
                # if not, restore the original states, reduce the learning rate and try again
                # no gradient needed for the plain loss calculation
                with torch.no_grad():
                    for i in range(10):

                        # new_batch_prediction = model(batch_input).transpose(1,2).contiguous()
                        new_batch_prediction = model(batch_input).transpose(
                            1, 2)
                        new_loss = cross_entropy(new_batch_prediction.reshape(
                            -1, 2),
                                                 batch_target.view(-1),
                                                 class_weights,
                                                 reduction='mean').item()

                        # if the model performs better now we continue, if not we try a smaller learning step
                        if (new_loss < batch_loss.item()):
                            # print("lucky! (%f -> %f) reduction: %.4f%%" % (batch_loss.item(), new_loss, 100 * (batch_loss.item()-new_loss) / batch_loss.item()))
                            break
                        else:
                            # print("try again.. (%f -> %f)" % (batch_loss.item(), new_loss))
                            model.load_state_dict(model_state_backup)
                            opt.load_state_dict(opt_state_backup)
                            lr_adjust *= 0.7
                            opt.step(lr_adjust=lr_adjust)
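                            # NOTE: step(lr_adjust=...) assumes a custom
                            # optimizer whose step() accepts a step-size
                            # scale; stock torch.optim optimizers do not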

                loss_reduction = 100 * (batch_loss.item() -
                                        new_loss) / batch_loss.item()

                if new_loss >= batch_loss.item():
                    failed_loss_optims += 1
                else:
                    cum_lr_adjust_fac += lr_adjust
                    cum_loss_reduction += loss_reduction

            else:

                cum_lr_adjust_fac += lr_adjust
                opt.step()

            checkpoint['train_batch_lr_adjust'][-1].append(lr_adjust)
            checkpoint['train_batch_loss_reduction'][-1].append(loss_reduction)

            # reset gradients
            opt.zero_grad()

            # make class prediction and save stats -----------------------

            success_vector = torch.argmax(batch_prediction,
                                          dim=2) == batch_target

            c0_idx = batch_target == 0
            c1_idx = batch_target == 1

            checkpoint['train_batch_acc'][-1].append(
                torch.sum(success_vector).item() / success_vector.numel())
            checkpoint['train_batch_C0_acc'][-1].append(
                torch.sum(success_vector[c0_idx]).item() /
                torch.sum(c0_idx).item())  # TODO handle division by zero
            checkpoint['train_batch_C1_acc'][-1].append(
                torch.sum(success_vector[c1_idx]).item() /
                torch.sum(c1_idx).item())  # TODO

            # statistics calculation and output -------------------------

            if b % report_interval == 0:

                last_100_loss = sum(checkpoint['train_batch_loss'][-1]
                                    [b - report_interval:b]) / report_interval
                last_100_acc = sum(checkpoint['train_batch_acc'][-1]
                                   [b - report_interval:b]) / report_interval
                last_100_acc_c0 = sum(
                    checkpoint['train_batch_C0_acc'][-1]
                    [b - report_interval:b]) / report_interval
                last_100_acc_c1 = sum(
                    checkpoint['train_batch_C1_acc'][-1]
                    [b - report_interval:b]) / report_interval

                io.cprint(
                    '  Batch %4d to %4d | loss: %.5f%1s| acc: %.4f%1s| C0 acc: %.4f%1s| C1 acc: %.4f%1s| E%3d | T:%5ds | Failed Optims: %3d (%05.2f%%) | Av. Adjust LR: %.6f | Av. Loss Reduction: %07.4f%%'
                    %
                    (b -
                     (report_interval - 1), b, last_100_loss, '+' if epoch > 1
                     and last_100_loss < checkpoint['train_loss'][-1] else '',
                     last_100_acc, '+' if epoch > 1
                     and last_100_acc > checkpoint['train_acc'][-1] else '',
                     last_100_acc_c0, '+' if epoch > 1
                     and last_100_acc_c0 > checkpoint['train_C0_acc'][-1] else
                     '', last_100_acc_c1, '+' if epoch > 1 and last_100_acc_c1
                     > checkpoint['train_C1_acc'][-1] else '', epoch,
                     time.time() - train_start, failed_loss_optims, 100 *
                     (failed_loss_optims / report_interval),
                     (cum_lr_adjust_fac /
                      (report_interval - failed_loss_optims)
                      if failed_loss_optims < report_interval else -1),
                     (cum_loss_reduction /
                      (report_interval - failed_loss_optims)
                      if failed_loss_optims < report_interval else -1)))

                failed_loss_optims = 0
                cum_lr_adjust_fac = 0
                cum_loss_reduction = 0

        checkpoint['train_loss'].append(
            sum(checkpoint['train_batch_loss'][-1]) / b)
        checkpoint['train_acc'].append(
            sum(checkpoint['train_batch_acc'][-1]) / b)
        checkpoint['train_C0_acc'].append(
            sum(checkpoint['train_batch_C0_acc'][-1]) / b)
        checkpoint['train_C1_acc'].append(
            sum(checkpoint['train_batch_C1_acc'][-1]) / b)
        checkpoint['train_time'].append(time.time() - train_start)

        io.cprint(
            '----\n  TRN | time: %5ds | loss: %.10f | acc: %.4f | C0 acc: %.4f | C1 acc: %.4f'
            % (checkpoint['train_time'][-1], checkpoint['train_loss'][-1],
               checkpoint['train_acc'][-1], checkpoint['train_C0_acc'][-1],
               checkpoint['train_C1_acc'][-1]))

        torch.cuda.empty_cache()

        ####################
        # Test & Validation
        ####################

        with torch.no_grad():

            if use_batch_norm:

                model.eval_bn()

                eval_bn_start = time.time()

                # run through all train samples again to accumulate layer-wise input distribution statistics (mean and variance) with fixed weights
                # these statistics are later used for the BatchNorm layers during inference
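                # (eval_bn() is presumably a model-specific mode that lets the
                # BatchNorm running mean/var update while all weights stay fixed)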
                for (bin_id, batch_ids) in binbatches:

                    batch_pts_ids = train_bins[bin_id][
                        batch_ids][:, :, 0]  # size: (B x N)
                    batch_input = pts[batch_pts_ids]  # size: (B x N x d)

                    # batch_input = batch_input.transpose(1,2).contiguous()             # size: (B x d x N)
                    batch_input = batch_input.transpose(1,
                                                        2)  # size: (B x d x N)
                    model(batch_input)

                io.cprint('Accumulated BN Layer statistics (%ds)' %
                          (time.time() - eval_bn_start))

            model.eval()

            if len(test_samples) > 0:

                test_start = time.time()

                test_loss, test_acc, test_acc_c0, test_acc_c1 = getTestLoss(
                    pts, test_samples, model, class_weights)

                checkpoint['test_loss'].append(test_loss)
                checkpoint['test_acc'].append(test_acc)
                checkpoint['test_C0_acc'].append(test_acc_c0)
                checkpoint['test_C1_acc'].append(test_acc_c1)

                checkpoint['test_time'].append(time.time() - test_start)

                io.cprint(
                    '  TST | time: %5ds | loss: %.10f | acc: %.4f | C0 acc: %.4f | C1 acc: %.4f'
                    %
                    (checkpoint['test_time'][-1], checkpoint['test_loss'][-1],
                     checkpoint['test_acc'][-1], checkpoint['test_C0_acc'][-1],
                     checkpoint['test_C1_acc'][-1]))

            else:
                io.cprint('  TST | n/a (no samples)')

            if len(val_samples) > 0:

                val_start = time.time()

                val_loss, val_acc, val_acc_c0, val_acc_c1 = getTestLoss(
                    val_pts, val_samples, model, class_weights)

                checkpoint['val_loss'].append(val_loss)
                checkpoint['val_acc'].append(val_acc)
                checkpoint['val_C0_acc'].append(val_acc_c0)
                checkpoint['val_C1_acc'].append(val_acc_c1)

                checkpoint['val_time'].append(time.time() - val_start)

                io.cprint(
                    '  VAL | time: %5ds | loss: %.10f | acc: %.4f | C0 acc: %.4f | C1 acc: %.4f'
                    % (checkpoint['val_time'][-1], checkpoint['val_loss'][-1],
                       checkpoint['val_acc'][-1], checkpoint['val_C0_acc'][-1],
                       checkpoint['val_C1_acc'][-1]))

            else:
                io.cprint('  VAL | n/a (no samples)')

        ####################
        # Scheduler Step
        ####################

        if not no_lr_schedule:
            scheduler.step()

        if epoch > 1 and dynamic_lr and sum(
                checkpoint['train_batch_lr_adjust'][-1]) > 0:
            io.cprint("----\n  dynamic lr adjust: %.10f" %
                      (0.5 *
                       (1 + sum(checkpoint['train_batch_lr_adjust'][-1]) /
                        len(checkpoint['train_batch_lr_adjust'][-1]))))
            for param_group in opt.param_groups:
                param_group['lr'] *= 0.5 * (
                    1 + sum(checkpoint['train_batch_lr_adjust'][-1]) /
                    len(checkpoint['train_batch_lr_adjust'][-1]))

        # Save model and optimizer state ..
        checkpoint['model_state_dict'].append(copy.deepcopy(
            model.state_dict()))
        checkpoint['optimizer_state_dict'].append(
            copy.deepcopy(opt.state_dict()))

        torch.save(checkpoint, exp_dir + '/detector_checkpoints.t7')

    io.cprint("\n-------------------------------------------------------" +
              ("\ntotal_time: %.2fh" % ((time.time() - start_time) / 3600)) +
              ("\ntrain_time: %.2fh" %
               (sum(checkpoint['train_time']) / 3600)) +
              ("\ntest_time: %.2fh" % (sum(checkpoint['test_time']) / 3600)) +
              ("\nval_time: %.2fh" % (sum(checkpoint['val_time']) / 3600)) +
              "\n-------------------------------------------------------" +
              "\nend_time: " + datetime.now().strftime("%Y-%m-%d_%H%M%S") +
              "\n-------------------------------------------------------")
Example #6
def train(args,
          log_dir,
          checkpoint_path,
          trainloader,
          testloader,
          tensorboard,
          c,
          model_name,
          ap,
          cuda=True,
          model_params=None):
    loss1_weight = c.train_config['loss1_weight']
    use_mixup = False if 'mixup' not in c.model else c.model['mixup']
    if use_mixup:
        mixup_alpha = 1 if 'mixup_alpha' not in c.model else c.model[
            'mixup_alpha']
        mixup_augmenter = Mixup(mixup_alpha=mixup_alpha)
        print("Enable Mixup with alpha:", mixup_alpha)

    model = return_model(c, model_params)

    if c.train_config['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=c.train_config['learning_rate'],
            weight_decay=c.train_config['weight_decay'])
    elif c.train_config['optimizer'] == 'adamw':
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=c.train_config['learning_rate'],
            weight_decay=c.train_config['weight_decay'])
    elif c.train_config['optimizer'] == 'radam':
        optimizer = RAdam(model.parameters(),
                          lr=c.train_config['learning_rate'],
                          weight_decay=c.train_config['weight_decay'])
    else:
        raise Exception("The %s  not is a optimizer supported" %
                        c.train['optimizer'])

    step = 0
    if checkpoint_path is not None:
        print("Continue training from checkpoint: %s" % checkpoint_path)
        try:
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            model.load_state_dict(checkpoint['model'])
            # resume the global step so the LR schedule continues correctly
            step = checkpoint.get('step', 0)
        except Exception:
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint, c)
            model.load_state_dict(model_dict)
            del model_dict
            step = 0
    else:
        print("Starting new training run")
        step = 0

    if c.train_config['lr_decay']:
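        # NoamLR: presumably linear warmup over `warmup_steps`, then decay
        # with the inverse square root of the step (the Transformer schedule)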
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.train_config['warmup_steps'],
                           last_epoch=step - 1)
    else:
        scheduler = None
    # move the model to the GPU
    if cuda:
        model = model.cuda()

    # define loss function
    if use_mixup:
        criterion = Clip_BCE()
    else:
        criterion = nn.BCELoss()
    eval_criterion = nn.BCELoss(reduction='sum')

    best_loss = float('inf')

    # early stop definitions
    early_epochs = 0

    model.train()
    for epoch in range(c.train_config['epochs']):
        for feature, target in trainloader:

            if cuda:
                feature = feature.cuda()
                target = target.cuda()

            if use_mixup:
                batch_len = len(feature)
                if (batch_len % 2) != 0:
                    batch_len -= 1
                    feature = feature[:batch_len]
                    target = target[:batch_len]

                mixup_lambda = torch.FloatTensor(
                    mixup_augmenter.get_lambda(batch_len)).to(feature.device)
                output = model(feature[:batch_len], mixup_lambda)
                target = do_mixup(target, mixup_lambda)
            else:
                output = model(feature)
            # Calculate loss
            if c.dataset['class_balancer_batch'] and not use_mixup:
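                # average the two per-class mean losses so each class
                # contributes equally, regardless of batch composition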
                idxs = (target == c.dataset['control_class'])
                loss_control = criterion(output[idxs], target[idxs])
                idxs = (target == c.dataset['patient_class'])
                loss_patient = criterion(output[idxs], target[idxs])
                loss = (loss_control + loss_patient) / 2
            else:
                loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # update lr decay scheme
            if scheduler:
                scheduler.step()
            step += 1

            loss = loss.item()
            if loss > 1e8 or math.isnan(loss):
                print("Loss exploded to %.02f at step %d!" % (loss, step))
                break
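                # NOTE: this break only exits the inner batch loop; training
                # resumes with the next epoch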

            # write loss to tensorboard
            if step % c.train_config['summary_interval'] == 0:
                tensorboard.log_training(loss, step)
                if c.dataset['class_balancer_batch'] and not use_mixup:
                    print("Write summary at step %d" % step, ' Loss: ', loss,
                          'Loss control:', loss_control.item(),
                          'Loss patient:', loss_patient.item())
                else:
                    print("Write summary at step %d" % step, ' Loss: ', loss)

            # save checkpoint file  and evaluate and save sample to tensorboard
            if step % c.train_config['checkpoint_interval'] == 0:
                save_path = os.path.join(log_dir, 'checkpoint_%d.pt' % step)
                torch.save(
                    {
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'step': step,
                        'config_str': str(c),
                    }, save_path)
                print("Saved checkpoint to: %s" % save_path)
                # run validation and save best checkpoint
                val_loss = validation(eval_criterion,
                                      ap,
                                      model,
                                      c,
                                      testloader,
                                      tensorboard,
                                      step,
                                      cuda=cuda,
                                      loss1_weight=loss1_weight)
                best_loss, _ = save_best_checkpoint(
                    log_dir, model, optimizer, c, step, val_loss, best_loss,
                    early_epochs
                    if c.train_config['early_stop_epochs'] != 0 else None)

        print('=================================================')
        print("Epoch %d End !" % epoch)
        print('=================================================')
        # run validation and save best checkpoint at end epoch
        val_loss = validation(eval_criterion,
                              ap,
                              model,
                              c,
                              testloader,
                              tensorboard,
                              step,
                              cuda=cuda,
                              loss1_weight=loss1_weight)
        best_loss, early_epochs = save_best_checkpoint(
            log_dir, model, optimizer, c, step, val_loss, best_loss,
            early_epochs if c.train_config['early_stop_epochs'] != 0 else None)
        if c.train_config['early_stop_epochs'] != 0:
            if early_epochs is not None:
                if early_epochs >= c.train_config['early_stop_epochs']:
                    break  # stop train
    return best_loss