def validate(epoch, model, ema=None):
    """
    Evaluates the cross entropy between p_data and p_model.
    """
    bpd_meter = utils.AverageMeter()
    ce_meter = utils.AverageMeter()

    if ema is not None:
        ema.swap()

    update_lipschitz(model)

    model.eval()

    correct = 0
    total = 0

    start = time.time()
    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(test_loader)):
            x = x.to(device)
            bpd, _, _ = compute_loss(x, model)
            bpd_meter.update(bpd.item(), x.size(0))

    val_time = time.time() - start

    if ema is not None:
        ema.swap()
    s = 'Epoch: [{0}]\tTime {1:.2f} | Test bits/dim {bpd_meter.avg:.4f}'.format(
        epoch, val_time, bpd_meter=bpd_meter)
    logger.info(s)
    return bpd_meter.avg
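
Nearly every snippet on this page updates a utils.AverageMeter, which is not shown. A minimal sketch of the val/avg/update(value, n) interface the snippets assume, following the conventional PyTorch-examples implementation:

class AverageMeter:
    """Tracks the most recent value and a weighted running average."""

    def __init__(self):
        self.val = 0.0    # last value passed to update()
        self.sum = 0.0    # weighted sum of values
        self.count = 0    # total weight seen so far
        self.avg = 0.0    # running mean: sum / count

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count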
Example #2
def train_handcraft(args, train_loader, valid_loader, index_loader,
                    valid_dataset, index_dataset, save_root, writer):
    ext = handcraft_extractor(args)
    start_epoch = 0

    if args.ckpt_path is not None:
        ext.load(args.ckpt_path)
    else:

        batch_time = u.AverageMeter()
        data_time = u.AverageMeter()
        start = time.time()
        pbar = tqdm.tqdm(enumerate(train_loader),
                         desc="Extract local descriptor!")

        if args.train is True:
            for batch_i, data in pbar:

                data_time.update(time.time() - start)
                start = time.time()

                ext.extract_ld(data)

                batch_time.update(time.time() - start)
                start = time.time()

                state_msg = ('Data time: {:0.5f}; Batch time: {:0.5f};'.format(
                    data_time.avg, batch_time.avg))

                pbar.set_description(state_msg)

            ext.build_voca(k=args.cluster)
            ext.extract_vlad()
            filename = os.path.join(save_root, 'ckpt', 'checkpoint.pkl')
            ext.save(filename)

    if (args.valid is True) or (args.valid_sample is True):
        pbar = tqdm.tqdm(enumerate(valid_loader),
                         desc="Extract query descriptor!")
        for batch_i, data in pbar:

            ext.extract_vlad_query(data)

        indexdb, validdb = ext.get_data()
        if args.metric == 0:
            ldm = mt.LocDegThreshMetric(args, indexdb, validdb, index_dataset,
                                        valid_dataset, 0,
                                        os.path.join(save_root, "result"))
    return
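
build_voca and extract_vlad above follow the standard VLAD pipeline: k-means the pooled local descriptors into a visual vocabulary, then aggregate per-word residuals. A sketch of the aggregation step; the function name and shapes are illustrative, not the project's actual API:

import numpy as np

def vlad_encode(descriptors, centroids):
    """Sum residuals of local descriptors against their nearest visual
    word, then L2-normalize. descriptors: (n, d); centroids: (k, d)."""
    k, d = centroids.shape
    assign = np.argmin(
        ((descriptors[:, None, :] - centroids[None, :, :]) ** 2).sum(-1),
        axis=1)
    vlad = np.zeros((k, d))
    for i in range(k):
        if (assign == i).any():
            vlad[i] = (descriptors[assign == i] - centroids[i]).sum(0)
    vlad = vlad.reshape(-1)
    return vlad / (np.linalg.norm(vlad) + 1e-12)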
Example #3
def validate(epoch, model, ema=None):
    """
    Evaluates the cross entropy between p_data and p_model.
    """
    bpd_meter = utils.AverageMeter()
    ce_meter = utils.AverageMeter()

    if ema is not None:
        ema.swap()

    update_lipschitz(model)

    model = parallelize(model)
    model.eval()

    correct = 0
    total = 0

    start = time.time()
    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(test_loader)):

            x = x.to(device)
            bpd, logits, _, _ = compute_loss(x, model)
            bpd_meter.update(bpd.item(), x.size(0))

            if args.task in ['classification', 'hybrid']:
                y = y.to(device)
                loss = criterion(logits, y)
                ce_meter.update(loss.item(), x.size(0))
                _, predicted = logits.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()
    val_time = time.time() - start

    if ema is not None:
        ema.swap()
    s = 'Epoch: [{0}]\tTime {1:.2f} | Test bits/dim {bpd_meter.avg:.4f}'.format(
        epoch, val_time, bpd_meter=bpd_meter)
    if args.task in ['classification', 'hybrid']:
        s += ' | CE {:.4f} | Acc {:.2f}'.format(ce_meter.avg,
                                                100 * correct / total)
    logger.info(s)
    return bpd_meter.avg
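
The paired ema.swap() calls evaluate the model with its exponential-moving-average weights and then restore the live training weights. The EMA class itself is not shown; a hypothetical sketch consistent with that usage:

import torch

class SwapEMA:
    """Hypothetical swap-style EMA: calling swap() once puts the averaged
    weights into the model, calling it again restores the live weights."""

    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = [p.detach().clone() for p in model.parameters()]

    def update(self):
        with torch.no_grad():
            for p, s in zip(self.model.parameters(), self.shadow):
                s.mul_(self.decay).add_(p, alpha=1.0 - self.decay)

    def swap(self):
        with torch.no_grad():
            for p, s in zip(self.model.parameters(), self.shadow):
                tmp = p.detach().clone()
                p.copy_(s)
                s.copy_(tmp)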
Example #4
    def train_on_dataset(self, data_loader, models, criterions, optimizers,
                         epoch, logs, **kwargs):
        """
            train on dataset for one epoch
        """

        loss_meters = [utils.AverageMeter() for _ in models]
        top1_meters = [utils.AverageMeter() for _ in models]

        for model in models:
            model.train()

        for i, (input_, target) in enumerate(data_loader):
            input_, target = self.to_cuda(input_, target)
            self.train_on_batch(input_, target, models, criterions, optimizers,
                                logs, loss_meters, top1_meters, **kwargs)

        self.write_log(logs, loss_meters, top1_meters, epoch, mode="train")

        return logs
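
train_on_batch is not shown; a hypothetical sketch consistent with the driver above, written as a plain function (in the class it would also take self):

def train_on_batch(input_, target, models, criterions, optimizers,
                   logs, loss_meters, top1_meters, **kwargs):
    # One optimization step per model; updates the per-model meters
    # that train_on_dataset aggregates into the epoch log.
    for model, criterion, optimizer, loss_meter, top1_meter in zip(
            models, criterions, optimizers, loss_meters, top1_meters):
        output = model(input_)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, pred = output.max(1)
        top1 = pred.eq(target).float().mean().item() * 100
        loss_meter.update(loss.item(), input_.size(0))
        top1_meter.update(top1, input_.size(0))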
Example #5
    def validate_on_dataset(self, data_loader, models, criterions, epoch, logs,
                            **kwargs):
        """
            validate on dataset
        """

        loss_meters = [utils.AverageMeter() for _ in models]
        top1_meters = [utils.AverageMeter() for _ in models]

        for model in models:
            model.eval()

        for i, (input_, target) in enumerate(data_loader):
            input_, target = self.to_cuda(input_, target)
            self.validate_on_batch(input_, target, models, criterions, logs,
                                   loss_meters, top1_meters)

        self.write_log(logs, loss_meters, top1_meters, epoch, mode="test")

        return logs
Example #6
def validate(epoch, model, data_loader, ema, device):
    """
    Evaluates the cross entropy between p_data and p_model.
    """
    bpd_meter = utils.AverageMeter()

    if ema is not None:
        ema.swap()

    model.eval()

    start = time.time()
    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(data_loader)):
            x = x.to(device)
            bpd = compute_loss(x, model)
            bpd_meter.update(bpd.item(), x.size(0))
    val_time = time.time() - start

    if ema is not None:
        ema.swap()

    return val_time, bpd_meter.avg
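
compute_loss is defined elsewhere; for density models like these it conventionally reports bits per dimension. A sketch of the nats-to-bits/dim conversion such a loss performs, assuming a per-example log-likelihood (illustrative only):

import math

def bits_per_dim(logpx, x):
    # Negative log-likelihood in nats, divided by the number of input
    # dimensions and by ln(2) to convert nats to bits.
    ndims = x[0].numel()
    return -logpx / (ndims * math.log(2))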
Example #7
def train(model, trainD, evalD, checkpt=None):
    global ndecs
    optim = torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             betas=(0.9, 0.99),
                             weight_decay=args.wd)
    #  sch = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.nepochs * trainD.N)

    if checkpt is not None:
        optim.load_state_dict(checkpt['optim'])
        ndecs = checkpt['ndecs']

    batch_time = utils.RunningAverageMeter(0.98)
    cg_meter = utils.RunningAverageMeter(0.98)
    gnorm_meter = utils.RunningAverageMeter(0.98)
    train_est_meter = utils.RunningAverageMeter(0.98**args.train_est_freq)

    best_logp = -float('inf')
    itr = 0 if checkpt is None else checkpt['iters']
    n_vals_without_improvement = 0
    model.train()
    while True:
        if itr >= args.nepochs * math.ceil(trainD.N / args.batch_size):
            break
        if 0 < args.early_stopping < n_vals_without_improvement:
            break
        for x in batch_iter(trainD.x, shuffle=True):
            if 0 < args.early_stopping < n_vals_without_improvement:
                break
            end = time.time()
            optim.zero_grad()

            x = cvt(x)
            train_est = [0] if itr % args.train_est_freq == 0 else None
            loss = -model.logp(x, extra=train_est).mean()
            if train_est is not None:
                train_est = train_est[0].mean().detach().item()

            if loss != loss:  # NaN is the only value not equal to itself
                raise ValueError('NaN encountered @ training logp!')

            loss.backward()

            if args.clip_grad == 0:
                parameters = [
                    p for p in model.parameters() if p.grad is not None
                ]
                grad_norm = torch.norm(
                    torch.stack([
                        torch.norm(p.grad.detach(), 2.0) for p in parameters
                    ]), 2.0)
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.clip_grad)

            optim.step()
            #  sch.step()

            gnorm_meter.update(float(grad_norm))
            cg_meter.update(sum(flows.CG_ITERS_TRACER))
            flows.CG_ITERS_TRACER.clear()
            batch_time.update(time.time() - end)
            if train_est is not None:
                train_est_meter.update(train_est)

            del loss
            gc.collect()
            torch.clear_autocast_cache()

            if itr % args.log_freq == 0:
                log_message = (
                    'Iter {:06d} | Epoch {:.2f} | Time {batch_time.val:.3f} | '
                    'GradNorm {gnorm_meter.avg:.2f} | CG iters {cg_meter.val} ({cg_meter.avg:.2f}) | '
                    'Train logp {train_logp.val:.6f} ({train_logp.avg:.6f})'.
                    format(itr,
                           float(itr) / (trainD.N / float(args.batch_size)),
                           batch_time=batch_time,
                           gnorm_meter=gnorm_meter,
                           cg_meter=cg_meter,
                           train_logp=train_est_meter))
                logger.info(log_message)

            # Validation loop.
            if itr % args.val_freq == 0:
                with eval_ctx(model, bruteforce=args.brute_val):
                    val_logp = utils.AverageMeter()
                    with tqdm(total=evalD.N) as pbar:
                        # noinspection PyAssignmentToLoopOrWithParameter
                        for x in batch_iter(evalD.x,
                                            batch_size=args.val_batch_size):
                            x = cvt(x)
                            val_logp.update(
                                model.logp(x).mean().item(), x.size(0))
                            pbar.update(x.size(0))
                    if val_logp.avg > best_logp:
                        best_logp = val_logp.avg
                        utils.makedirs(args.save)
                        torch.save(
                            {
                                'args': args,
                                'model': model.state_dict(),
                                'optim': optim.state_dict(),
                                'iters': itr + 1,
                                'ndecs': ndecs,
                            }, save_path)
                        n_vals_without_improvement = 0
                    else:
                        n_vals_without_improvement += 1
                        update_lr(optim, n_vals_without_improvement)

                    log_message = ('[VAL] Iter {:06d} | Val logp {:.6f} | '
                                   'NoImproveEpochs {:02d}/{:02d}'.format(
                                       itr, val_logp.avg,
                                       n_vals_without_improvement,
                                       args.early_stopping))
                    logger.info(log_message)

            itr += 1

    logger.info('Training has finished, yielding the best model...')
    best_checkpt = torch.load(save_path)
    model.load_state_dict(best_checkpt['model'])
    return model
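
utils.RunningAverageMeter(0.98) above smooths with momentum rather than keeping a true mean; a minimal sketch of the behavior the log messages assume:

class RunningAverageMeter:
    """Exponentially smoothed meter: avg is a momentum-weighted average
    of recent values, val is the last raw value."""

    def __init__(self, momentum=0.98):
        self.momentum = momentum
        self.val = None
        self.avg = 0.0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1.0 - self.momentum)
        self.val = val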
Example #8
            "Resume file provided, but not found... starting from scratch: {}".
            format(args.resume))

    logger.info(flow)
    logger.info("Number of trainable parameters:{}".format(
        count_parameters(flow)))

    ################################################################################
    #                                   Training                                   #
    ################################################################################

    if not args.evaluate:
        flow = train(flow, data.trn, data.val, checkpt)

    ################################################################################
    #                                   Testing                                    #
    ################################################################################

    logger.info('Evaluating model on test set.')
    with eval_ctx(flow, bruteforce=True):
        test_logp = utils.AverageMeter()
        with tqdm(total=data.tst.N) as pbar:
            for itr, x in enumerate(
                    batch_iter(data.tst.x, batch_size=args.test_batch_size)):
                x = cvt(x)
                test_logp.update(flow.logp(x).mean().item(), x.size(0))
                pbar.update(x.size(0))
        log_message = '[TEST] Iter {:06d} | Test logp {:.6f}'.format(
            itr, test_logp.avg)
        logger.info(log_message)
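
batch_iter is another shared helper that is not shown; a plausible sketch matching the shuffle and batch_size keywords used above (an assumption about its behavior, not the project's actual code):

import torch

def batch_iter(x, batch_size=1024, shuffle=False):
    """Yield mini-batches of rows from a 2-D tensor, optionally shuffled."""
    n = x.shape[0]
    idx = torch.randperm(n) if shuffle else torch.arange(n)
    for start in range(0, n, batch_size):
        yield x[idx[start:start + batch_size]]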
Example #9
        filtered_state_dict = {}
        for k, v in checkpt['state_dict'].items():
            if 'diffeq.diffeq' not in k:
                filtered_state_dict[k.replace('module.', '')] = v
        model.load_state_dict(filtered_state_dict)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(
        count_parameters(model)))

    if not args.evaluate:
        optimizer = Adam(model.parameters(),
                         lr=args.lr,
                         weight_decay=args.weight_decay)

        time_meter = utils.AverageMeter()
        loss_meter = utils.AverageMeter()
        nfef_meter = utils.AverageMeter()
        nfeb_meter = utils.AverageMeter()
        tt_meter = utils.AverageMeter()

        best_loss = float('inf')
        itr = 0
        n_vals_without_improvement = 0
        end = time.time()
        model.train()
        while True:
            if args.early_stopping > 0 and n_vals_without_improvement > args.early_stopping:
                break

            for x in batch_iter(data.trn.x, shuffle=True):
Example #10
def train(args, train_loader, valid_loader, index_loader, valid_dataset,
          index_dataset, save_root, writer):
    ext = extractor(args)
    crt = criterion(args)

    optim = get_optimizer(args, ext)
    sche = get_scheduler(args, optim)

    start_epoch = 0

    if args.ckpt_path is not None:
        ckpt = torch.load(args.ckpt_path)
        ext = ckpt['model']
        optim = ckpt['optimizer']  # note: sche above still wraps the original optimizer object, not this one
        start_epoch = ckpt['epoch'] + 1

    for epoch in range(start_epoch, args.epochs):
        ext.train()
        sche.step()  # stepped at epoch start (pre-1.1 PyTorch order); newer code steps after the epoch's optimizer steps

        batch_time = u.AverageMeter()
        data_time = u.AverageMeter()
        losses = u.AverageMeter()

        start = time.time()

        pbar = tqdm.tqdm(enumerate(train_loader), desc="Epoch : %d" % epoch)

        size_all = len(train_loader)
        interv = max(1, math.floor(size_all / args.save_interval))  # guard against a zero interval (modulo-by-zero below)
        sub_p = 0

        if args.train is True:
            for batch_i, data in pbar:

                image = data['image'].cuda()
                label = data['label'].cuda()

                data_time.update(time.time() - start)
                start = time.time()

                output = ext(image)
                if output.dim() == 1:
                    output = output.unsqueeze(0)

                loss = crt(output, label, args.tuple, args.batch)

                optim.zero_grad()
                loss.backward()
                optim.step()

                losses.update(loss.item())
                batch_time.update(time.time() - start)
                start = time.time()

                writer.add_scalars('train/loss', {'loss': losses.avg},
                                   global_step=epoch * len(train_loader) +
                                   batch_i)

                state_msg = (
                    'Epoch: {:4d}; Loss: {:0.5f}; Data time: {:0.5f}; Batch time: {:0.5f};'
                    .format(epoch, losses.avg, data_time.avg, batch_time.avg))

                pbar.set_description(state_msg)

                if ((batch_i + 1) % interv == 0) or (size_all == batch_i + 1):

                    state = {
                        'epoch': epoch,
                        'loss': losses.avg,
                        'model': ext,
                        'optimizer': optim
                    }
                    filename = os.path.join(
                        save_root, 'ckpt',
                        'checkpoint_subset{:03d}_epoch{:03d}.pth.tar'.format(
                            sub_p, epoch))
                    torch.save(state, filename)
                    sub_p += 1

        if ((args.valid is True) or (args.valid_sample is True)) and (
            (epoch + 1) % args.valid_interval == 0):

            if (args.db_load is not None):
                with open(args.db_load, "rb") as a_file:
                    indexdb = pickle.load(a_file)
            else:
                # #index
                indexdb = make_inferDBandPredict(args,
                                                 index_loader,
                                                 ext,
                                                 epoch,
                                                 tp='index')

            #valid
            validdb = make_inferDBandPredict(args,
                                             valid_loader,
                                             ext,
                                             epoch,
                                             tp='valid')

            if args.db_load is None:
                if (args.extractor >= 4):
                    indexdb['feat'], validdb['feat'] = ext.postprocessing(
                        indexdb['feat'], validdb['feat'])

                if args.pca is True:
                    pca = pp.PCAwhitening(pca_dim=args.pca_dim,
                                          pca_whitening=True)
                    indexdb['feat'] = pca.fit_transform(indexdb['feat'])
                    validdb['feat'] = pca.transform(validdb['feat'])

            if args.db_save is not None:
                if os.path.isfile(args.db_save):
                    os.remove(args.db_save)
                with open(args.db_save, "wb") as a_file:
                    pickle.dump(indexdb, a_file)

            ldm = {}  # populated only when args.metric == 0; guards the writer loop below
            if args.metric == 0:
                ldm = mt.LocDegThreshMetric(args, indexdb, validdb,
                                            index_dataset, valid_dataset,
                                            epoch,
                                            os.path.join(save_root, "result"))

            if args.train is False:
                return

            if args.qualitative:
                return

            for key, value in ldm.items():
                writer.add_scalars('valid/top' + str(args.topk) + "_" + key,
                                   {key: value},
                                   global_step=epoch)

    return
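
pp.PCAwhitening is project-specific. A hypothetical equivalent using scikit-learn's PCA with whiten=True, matching the fit-on-index / transform-on-query split used above:

from sklearn.decomposition import PCA

class PCAWhitening:
    """Illustrative stand-in for pp.PCAwhitening: reduce descriptors to
    pca_dim components and whiten them."""

    def __init__(self, pca_dim=256, pca_whitening=True):
        self.pca = PCA(n_components=pca_dim, whiten=pca_whitening)

    def fit_transform(self, feats):
        return self.pca.fit_transform(feats)   # fit on the index set

    def transform(self, feats):
        return self.pca.transform(feats)       # reuse the same stats for queries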
Example #11
    logger.info("saveLocation = {:}".format(args.save))
    logger.info("-------------------------\n")

    begin = time.time()
    end = begin
    best_loss = float('inf')
    best_costs = [0.0] * 3
    best_params = None

    log_msg = (
        '{:5s}  {:6s}  {:7s}   {:9s}  {:9s}  {:9s}  {:9s}     {:9s}  {:9s}  {:9s}  {:9s} '
        .format('iter', ' time', 'lr', 'loss', 'L (L_2)', 'C (loss)',
                'R (HJB)', 'valLoss', 'valL', 'valC', 'valR'))
    logger.info(log_msg)

    timeMeter = utils.AverageMeter()
    clampMax = 2.0
    clampMin = -2.0

    net.train()
    itr = 1
    while itr < args.niters:
        # train
        for data in train_loader:
            images, _ = data
            # flatten images
            x0 = images.view(images.size(0), -1)
            x0 = cvt(x0)
            x0 = autoEnc.encode(x0)  # encode
            x0 = (x0 - autoEnc.mu) / (autoEnc.std + args.eps)  # normalize
Example #12
def validate(epoch, model, gmm, ema=None):
    """
    - Deploys the color normalization on the test image dataset
    - Evaluates NMI / CV / SD
    # Evaluates the cross entropy between p_data and p_model.
    """
    print("Starting Validation")
    model = parallelize(model)
    gmm = parallelize(gmm)

    model.to(device)
    gmm.to(device)

    bpd_meter = utils.AverageMeter()
    ce_meter = utils.AverageMeter()

    if ema is not None:
        ema.swap()

    update_lipschitz(model)

    model.eval()
    gmm.eval()

    mu_tmpl = 0
    std_tmpl = 0
    N = 0

    print(
        f"Deploying on {len(train_loader)} batches of {args.batchsize} templates..."
    )
    idx = 0
    for x, y in tqdm(train_loader):
        x = x.to(device)
        ### TEMPLATES ###
        D = x[:, 0, ...].unsqueeze(1)
        D = rescale(D)  # Scale to [0,1] interval
        D = D.repeat(1, args.nclusters, 1, 1)
        with torch.no_grad():
            if isinstance(model, torch.nn.DataParallel):
                z_logp = model.module(D.view(-1, *input_size[1:]),
                                      0,
                                      classify=False)
            else:
                z_logp = model(D.view(-1, *input_size[1:]), 0, classify=False)

            z, delta_logp = z_logp
            if isinstance(gmm, torch.nn.DataParallel):
                logpz, params = gmm.module(
                    z.view(-1, args.nclusters, args.imagesize, args.imagesize),
                    x.permute(0, 2, 3, 1))
            else:
                logpz, params = gmm(
                    z.view(-1, args.nclusters, args.imagesize, args.imagesize),
                    x.permute(0, 2, 3, 1))

        mu, std, gamma = params
        mu = mu.cpu().numpy()
        std = std.cpu().numpy()
        gamma = gamma.cpu().numpy()

        mu = mu[..., np.newaxis]
        std = std[..., np.newaxis]

        mu = np.swapaxes(mu, 0, 1)  # (3,4,1) -> (4,3,1)
        mu = np.swapaxes(mu, 1, 2)  # (4,3,1) -> (4,1,3)
        std = np.swapaxes(std, 0, 1)  # (3,4,1) -> (4,3,1)
        std = np.swapaxes(std, 1, 2)  # (4,3,1) -> (4,1,3)

        N = N + 1
        # streaming mean over batches: m_N = (N-1)/N * m_{N-1} + x_N / N
        mu_tmpl = (N - 1) / N * mu_tmpl + 1 / N * mu
        std_tmpl = (N - 1) / N * std_tmpl + 1 / N * std

        if idx == len(train_loader) - 1: break
        idx += 1

    print("Estimated Mu for template(s):")
    print(mu_tmpl)

    print("Estimated Sigma for template(s):")
    print(std_tmpl)

    metrics = dict()
    for tc in range(1, args.nclusters + 1):
        metrics[f'mean_{tc}'] = []
        metrics[f'median_{tc}'] = []
        metrics[f'perc_95_{tc}'] = []
        metrics[f'nmi_{tc}'] = []
        metrics[f'sd_{tc}'] = []
        metrics[f'cv_{tc}'] = []

    print(
        f"Predicting on {len(test_loader)} batches of {args.val_batchsize} templates..."
    )
    idx = 0
    for x_test, y_test in tqdm(test_loader):
        x_test = x_test.to(device)
        ### DEPLOY ###
        D = x_test[:, 0, ...].unsqueeze(1)
        D = rescale(D)  # Scale to [0,1] interval
        D = D.repeat(1, args.nclusters, 1, 1)
        with torch.no_grad():
            if isinstance(model, torch.nn.DataParallel):
                z_logp = model.module(D.view(-1, *input_size[1:]),
                                      0,
                                      classify=False)
            else:
                z_logp = model(D.view(-1, *input_size[1:]), 0, classify=False)

            z, delta_logp = z_logp
            if isinstance(gmm, torch.nn.DataParallel):
                logpz, params = gmm.module(
                    z.view(-1, args.nclusters, args.imagesize, args.imagesize),
                    x_test.permute(0, 2, 3, 1))
            else:
                logpz, params = gmm(
                    z.view(-1, args.nclusters, args.imagesize, args.imagesize),
                    x_test.permute(0, 2, 3, 1))

        mu, std, pi = params
        mu = mu.cpu().numpy()
        std = std.cpu().numpy()
        pi = pi.cpu().numpy()

        mu = mu[..., np.newaxis]
        std = std[..., np.newaxis]

        mu = np.swapaxes(mu, 0, 1)  # (3,4,1) -> (4,3,1)
        mu = np.swapaxes(mu, 1, 2)  # (4,3,1) -> (4,1,3)
        std = np.swapaxes(std, 0, 1)  # (3,4,1) -> (4,3,1)
        std = np.swapaxes(std, 1, 2)  # (4,3,1) -> (4,1,3)

        X_hsd = np.swapaxes(x_test.cpu().numpy(), 1, 2)
        X_hsd = np.swapaxes(X_hsd, 2, 3)

        X_conv = imgtf.image_dist_transform(X_hsd, mu, std, pi, mu_tmpl,
                                            std_tmpl, args)

        ClsLbl = np.argmax(np.asarray(pi), axis=-1)
        ClsLbl = ClsLbl.astype('int32')
        mean_rgb = np.mean(X_conv, axis=-1)
        for tc in range(1, args.nclusters + 1):
            msk = ClsLbl == tc
            if not msk.any():
                continue  # skip metric if no class labels are found
            ma = mean_rgb[msk]
            mean = np.mean(ma)
            median = np.median(ma)
            perc = np.percentile(ma, 95)
            nmi = median / perc
            metrics[f'mean_{tc}'].append(mean)
            metrics[f'median_{tc}'].append(median)
            metrics[f'perc_95_{tc}'].append(perc)
            metrics[f'nmi_{tc}'].append(nmi)

        if idx == len(test_loader) - 1: break
        idx += 1

    av_sd = []
    av_cv = []
    for tc in range(1, args.nclusters + 1):
        if len(metrics[f'mean_{tc}']) == 0: continue
        metrics[f'sd_{tc}'] = np.array(metrics[f'nmi_{tc}']).std()
        metrics[f'cv_{tc}'] = np.array(metrics[f'nmi_{tc}']).std() / np.array(
            metrics[f'nmi_{tc}']).mean()
        print(f'sd_{tc}:', metrics[f'sd_{tc}'])
        print(f'cv_{tc}:', metrics[f'cv_{tc}'])
        av_sd.append(metrics[f'sd_{tc}'])
        av_cv.append(metrics[f'cv_{tc}'])

    print(f"Average sd = {np.array(av_sd).mean()}")
    print(f"Average cv = {np.array(av_cv).mean()}")
    import csv
    with open(f"metrics-{args.train_centers[0]}-{args.val_centers[0]}.csv",
              "w") as file:
        writer = csv.writer(file)
        for key, value in metrics.items():
            writer.writerow([key, value])

    # correct = 0
    # total = 0

    # start = time.time()
    # with torch.no_grad():
    #     for i, (x, y) in enumerate(tqdm(test_loader)):
    #         x = x.to(device)

    #         bpd, logits, _, _ = compute_loss(x, model)
    #         bpd_meter.update(bpd.item(), x.size(0))

    # val_time = time.time() - start

    # if ema is not None:
    #     ema.swap()
    # s = 'Epoch: [{0}]\tTime {1:.2f} | Test bits/dim {bpd_meter.avg:.4f}'.format(epoch, val_time, bpd_meter=bpd_meter)
    # if args.task in ['classification', 'hybrid']:
    #     s += ' | CE {:.4f} | Acc {:.2f}'.format(ce_meter.avg, 100 * correct / total)
    # logger.info(s)
    # return bpd_meter.avg

    return
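
The template statistics above use an incremental update, so mu_tmpl ends up equal to the plain average of the per-batch means without storing them all. A tiny self-check of that recurrence:

import numpy as np

def streaming_mean(values):
    # m_N = (N - 1) / N * m_{N-1} + x_N / N, as in the loop above.
    m, n = 0.0, 0
    for x in values:
        n += 1
        m = (n - 1) / n * m + x / n
    return m

assert np.isclose(streaming_mean([1.0, 2.0, 3.0]), 2.0)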
Example #13
def run(args, kwargs):
    # ==================================================================================================================
    # SNAPSHOTS
    # ==================================================================================================================
    args.model_signature = str(datetime.datetime.now())[0:19].replace(' ', '_')
    args.model_signature = args.model_signature.replace(':', '_')

    if args.automatic_saving:
        path = '{}/{}/{}/{}/{}/{}/{}/{}/{}/'.format(args.solver, args.dataset,
                                                    args.layer_type, args.atol,
                                                    args.rtol, args.atol_start,
                                                    args.rtol_start,
                                                    args.warmup_steps,
                                                    args.manual_seed)
    else:
        path = 'test/'

    args.snap_dir = os.path.join(args.out_dir, path)

    # logger (utils.makedirs creates the directory if it is missing)
    utils.makedirs(args.snap_dir)
    logger = utils.get_logger(logpath=os.path.join(args.snap_dir, 'logs'),
                              filepath=os.path.abspath(__file__))

    logger.info(args)

    # SAVING
    torch.save(args, args.snap_dir + 'config.config')

    # ==================================================================================================================
    # LOAD DATA
    # ==================================================================================================================
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    if not args.evaluate:

        nfef_meter = utils.AverageMeter()
        nfeb_meter = utils.AverageMeter()

        # ==============================================================================================================
        # SELECT MODEL
        # ==============================================================================================================
        # flow parameters and architecture choice are passed on to model through args

        if args.flow == 'no_flow':
            model = VAE.VAE(args)
        elif args.flow == 'planar':
            model = VAE.PlanarVAE(args)
        elif args.flow == 'iaf':
            model = VAE.IAFVAE(args)
        elif args.flow == 'orthogonal':
            model = VAE.OrthogonalSylvesterVAE(args)
        elif args.flow == 'householder':
            model = VAE.HouseholderSylvesterVAE(args)
        elif args.flow == 'triangular':
            model = VAE.TriangularSylvesterVAE(args)
        elif args.flow == 'cnf':
            model = CNFVAE.CNFVAE(args)
        elif args.flow == 'cnf_bias':
            model = CNFVAE.AmortizedBiasCNFVAE(args)
        elif args.flow == 'cnf_hyper':
            model = CNFVAE.HypernetCNFVAE(args)
        elif args.flow == 'cnf_lyper':
            model = CNFVAE.LypernetCNFVAE(args)
        elif args.flow == 'cnf_rank':
            model = CNFVAE.AmortizedLowRankCNFVAE(args)
        else:
            raise ValueError('Invalid flow choice')

        if args.retrain_encoder:
            logger.info(f"Initializing decoder from {args.model_path}")
            dec_model = torch.load(args.model_path)
            dec_sd = {}
            for k, v in dec_model.state_dict().items():
                if 'p_x' in k:
                    dec_sd[k] = v
            model.load_state_dict(dec_sd, strict=False)

        if args.cuda:
            logger.info("Model on GPU")
            model.cuda()

        logger.info(model)
        logger.info("Number of trainable parameters: {}".format(
            count_parameters(model)))

        if args.retrain_encoder:
            parameters = []
            logger.info('Optimizing over:')
            for name, param in model.named_parameters():
                if 'p_x' not in name:
                    logger.info(name)
                    parameters.append(param)
        else:
            parameters = model.parameters()

        optimizer = optim.Adamax(parameters, lr=args.learning_rate, eps=1.e-7)

        # ==================================================================================================================
        # TRAINING
        # ==================================================================================================================
        train_loss = []
        val_loss = []

        # for early stopping
        best_loss = np.inf
        best_bpd = np.inf
        e = 0
        epoch = 0

        train_times = []

        for epoch in range(1, args.epochs + 1):
            atol, rtol = update_tolerances(args, epoch, decay_factors)
            logger.info('atol: {}'.format(atol))
            set_cnf_options(args, atol, rtol, model)

            t_start = time.time()

            if 'cnf' not in args.flow:
                tr_loss = train(epoch, train_loader, model, optimizer, args,
                                logger)
            else:
                tr_loss, nfef_meter, nfeb_meter = train(
                    epoch, train_loader, model, optimizer, args, logger,
                    nfef_meter, nfeb_meter)

            train_loss.append(tr_loss)
            train_times.append(time.time() - t_start)
            logger.info('One training epoch took %.2f seconds' %
                        (time.time() - t_start))

            v_loss, v_bpd = evaluate(val_loader,
                                     model,
                                     args,
                                     logger,
                                     epoch=epoch)

            val_loss.append(v_loss)

            # early-stopping
            if v_loss < best_loss:
                e = 0
                best_loss = v_loss
                if args.input_type != 'binary':
                    best_bpd = v_bpd
                logger.info('->model saved<-')
                torch.save(model, args.snap_dir + 'model.model')
                # torch.save(model, snap_dir + args.flow + '_' + args.architecture + '.model')

            elif (args.early_stopping_epochs > 0) and (epoch >= args.warmup):
                e += 1
                if e > args.early_stopping_epochs:
                    break

            if args.input_type == 'binary':
                logger.info(
                    '--> Early stopping: {}/{} (BEST: loss {:.4f})\n'.format(
                        e, args.early_stopping_epochs, best_loss))

            else:
                logger.info(
                    '--> Early stopping: {}/{} (BEST: loss {:.4f}, bpd {:.4f})\n'
                    .format(e, args.early_stopping_epochs, best_loss,
                            best_bpd))

            if math.isnan(v_loss):
                raise ValueError('NaN encountered!')

        train_loss = np.hstack(train_loss)
        val_loss = np.array(val_loss)

        plot_training_curve(train_loss,
                            val_loss,
                            fname=args.snap_dir + '/training_curve.pdf')

        # training time per epoch
        train_times = np.array(train_times)
        mean_train_time = np.mean(train_times)
        std_train_time = np.std(train_times, ddof=1)
        logger.info('Average train time per epoch: %.2f +/- %.2f' %
                    (mean_train_time, std_train_time))

        # ==================================================================================================================
        # EVALUATION
        # ==================================================================================================================

        logger.info(args)
        logger.info('Stopped after %d epochs' % epoch)
        logger.info('Average train time per epoch: %.2f +/- %.2f' %
                    (mean_train_time, std_train_time))

        final_model = torch.load(args.snap_dir + 'model.model')
        validation_loss, validation_bpd = evaluate(val_loader, final_model,
                                                   args, logger)

    else:
        validation_loss = "N/A"
        validation_bpd = "N/A"
        logger.info(f"Loading model from {args.model_path}")
        final_model = torch.load(args.model_path)

    test_loss, test_bpd = evaluate(test_loader,
                                   final_model,
                                   args,
                                   logger,
                                   testing=True)

    # '{}' (not '{:.4f}') because validation_loss is the string "N/A" in evaluate mode
    logger.info('FINAL EVALUATION ON VALIDATION SET. ELBO (VAL): {}'.format(
        validation_loss))
    logger.info('FINAL EVALUATION ON TEST SET. ELBO (TEST): {:.4f}'.format(
        test_loss))
    return atol, rtol
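
update_tolerances and decay_factors are defined elsewhere in this project; one plausible reading is a per-step geometric decay of the ODE-solver tolerances from their *_start values down to the configured floors. Purely illustrative:

def update_tolerances(args, step, decay_factors):
    # Hypothetical schedule: geometric decay toward the target tolerances.
    atol_decay, rtol_decay = decay_factors
    atol = max(args.atol, args.atol_start * atol_decay ** step)
    rtol = max(args.rtol, args.rtol_start * rtol_decay ** step)
    return atol, rtol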


Example #14
if __name__ == '__main__':

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 2, regularization_fns).to(device)
    if args.spectral_norm: add_spectral_norm(model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    if not args.only_viz_samples:
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        time_meter = utils.AverageMeter()
        loss_meter = utils.AverageMeter()
        nfef_meter = utils.AverageMeter()
        nfeb_meter = utils.AverageMeter()
        tt_meter = utils.AverageMeter()

        end = time.time()
        best_loss = float('inf')
        model.train()
        for itr in range(1, args.niters + 1):
            atol, rtol = update_tolerances(args, itr, decay_factors)
            set_cnf_options(args, atol, rtol, model)

            optimizer.zero_grad()
            if args.spectral_norm: spectral_norm_power_iteration(model, 1)
Example #15
    x0 = torch.randn(args.batch_size, 1) - 3 + 6 * (torch.rand(args.batch_size, 1) > 0.5).float()  # bimodal toy data: modes at -3 and +3
    x0 = cvt(x0)

    
    # x0val = toy_data.inf_train_gen(args.data, batch_size=args.val_batch_size)
    x0val = torch.randn(args.batch_size, 1) - 3 + 6 * (torch.rand(args.batch_size, 1) > 0.5).float()
    x0val = cvt(x0val)

    log_msg = (
        '{:5s}  {:6s}   {:9s}  {:9s}  {:9s}  {:9s}      {:9s}  {:9s}  {:9s}  {:9s}  '.format(
            'iter', ' time','loss', 'L (L_2)', 'C (loss)', 'R (HJB)', 'valLoss', 'valL', 'valC', 'valR'
        )
    )
    logger.info(log_msg)

    time_meter = utils.AverageMeter()

    net.train()
    for itr in range(1, args.niters + 1):
        # train
        optim.zero_grad()
        loss, costs  = compute_loss(net, x0, nt=nt)
        loss.backward()
        optim.step()

        time_meter.update(time.time() - end)

        log_message = (
            '{:05d}  {:6.3f}   {:9.3e}  {:9.3e}  {:9.3e}  {:9.3e}  '.format(
                itr, time_meter.val, loss, costs[0], costs[1], costs[2]
            )
Example #16
        logger.info(
            'must use --resume flag to provide the state_dict to evaluate')
        exit(1)

    logger.info(model)
    nWeights = count_parameters(model)
    logger.info("Number of trainable parameters: {}".format(nWeights))
    logger.info('Evaluating model on test set.')
    model.eval()

    override_divergence_fn(model, "brute_force")

    bInverse = True  # check one batch for inverse error, for speed

    with torch.no_grad():
        test_loss = utils.AverageMeter()
        test_nfe = utils.AverageMeter()
        for itr, x in enumerate(
                batch_iter(data.tst.x, batch_size=test_batch_size)):

            x = cvt(x)
            test_loss.update(compute_loss(x, model).item(), x.shape[0])
            test_nfe.update(count_nfe(model))

            if bInverse:  # check the inverse error
                z = model(x, reverse=False)  # push forward
                xpred = model(z, reverse=True)  # inverse
                logger.info('inverse norm for first batch: ')
                logger.info(torch.norm(xpred - x).item() / x.shape[0])
                bInverse = False
Example #17
    if not cf.gpu:
        # assume debugging and run a subset
        nSamples = 1000
        testData = testData[:nSamples, :]
        normSamples = normSamples[:nSamples, :]
        if args.long_version:
            ffjordFx = ffjordFx[:nSamples, :]
            ffjordFinvfx = ffjordFinvfx[:nSamples, :]
            ffjordGen    = ffjordGen[:nSamples, :]

    net.eval()
    with torch.no_grad():

        # meters to hold testing results
        testLossMeter  = utils.AverageMeter()
        testAlphMeterL = utils.AverageMeter()
        testAlphMeterC = utils.AverageMeter()
        testAlphMeterR = utils.AverageMeter()

        # scale the GAS data set as it was in the training
        if args.data == 'gas':
            print(torch.min(testData), torch.max(testData))
            testData = testData / 5.0

        itr = 1
        for x0 in batch_iter(testData, batch_size=args.batch_size):

            x0 = cvt(x0)
            nex = x0.shape[0]
            test_loss, test_cs = compute_loss(net, x0, nt=nt_test)