Example #1
import os

import torch.nn.functional as F
from torch.autograd import Variable

from dataloader import FlexibleCustomDataloader


def train_classifier(networks, optimizers, dataloader, epoch=None, **options):
    for net in networks.values():
        net.train()
    netC = networks['classifier_kplusone']
    optimizerC = optimizers['classifier_kplusone']
    batch_size = options['batch_size']
    image_size = options['image_size']

    dataset_filename = options.get('aux_dataset')
    if not dataset_filename or not os.path.exists(dataset_filename):
        raise ValueError("Aux Dataset not available")
    print("Using aux_dataset {}".format(dataset_filename))
    aux_dataloader = FlexibleCustomDataloader(dataset_filename, batch_size=batch_size, image_size=image_size)

    for i, (images, class_labels) in enumerate(dataloader):
        images = Variable(images).cuda()
        labels = Variable(class_labels).cuda()

        ############################
        # Classifier Update
        ############################
        netC.zero_grad()

        # Classify real examples into the correct K classes
        classifier_logits = netC(images)
        augmented_logits = F.pad(classifier_logits, (0, 1))
        # _, labels_idx = labels.max(dim=1)
        labels_idx = labels
        errC = F.nll_loss(F.log_softmax(augmented_logits, dim=1), labels_idx)
        errC.backward()
        # log.collect('Classifier Loss', errC)

        # Classify aux_dataset examples as open set
        aux_images, aux_labels = aux_dataloader.get_batch()
        classifier_logits = netC(Variable(aux_images))
        augmented_logits = F.pad(classifier_logits, (0, 1))
        log_soft_open = F.log_softmax(augmented_logits, dim=1)[:, -1]
        errOpenSet = -log_soft_open.mean()
        errOpenSet.backward()
        # log.collect('Open Set Loss', errOpenSet)

        optimizerC.step()
        ############################

        # Keep track of accuracy on positive-labeled examples for monitoring
        # log.collect_prediction('Classifier Accuracy', netC(images), labels)

        # log.print_every()
    results = {
        'errC': errC.item(),
        'errOpenSet': errOpenSet.item(),
    }

    return results
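Note: the trick shared by several of these examples is the K+1 "background class" augmentation: the network emits K logits and a constant zero logit is appended as an extra open-set column before the softmax, so known classes and the open set compete in one distribution. Below is a minimal self-contained sketch of just that step on dummy tensors; the shapes are illustrative, not taken from the project.

import torch
import torch.nn.functional as F

# Toy batch: 4 examples, K = 10 known classes.
logits = torch.randn(4, 10)
labels = torch.tensor([3, 7, 0, 9])

# Append a constant zero logit as the implicit (K+1)-th "open set" column.
augmented_logits = F.pad(logits, (0, 1))  # shape: (4, 11)

# Known-class loss: ordinary NLL over the augmented softmax.
known_loss = F.nll_loss(F.log_softmax(augmented_logits, dim=1), labels)

# Open-set loss (applied to aux_dataset batches above): maximize the
# probability mass assigned to the last column.
open_loss = -F.log_softmax(augmented_logits, dim=1)[:, -1].mean()
print(known_loss.item(), open_loss.item())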
Example #2
import argparse
import os
import sys

parser = argparse.ArgumentParser()
parser.add_argument('--mode',
                    default='',
                    help='If set to "baseline" use the baseline classifier')

options = vars(parser.parse_args())

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from dataloader import CustomDataloader, FlexibleCustomDataloader
from training import train_classifier
from networks import build_networks, save_networks, get_optimizers
from options import load_options, get_current_epoch
from comparison import evaluate_with_comparison
from evaluation import save_evaluation

options = load_options(options)
dataloader = FlexibleCustomDataloader(fold='train', **options)
networks = build_networks(dataloader.num_classes, **options)
optimizers = get_optimizers(networks, finetune=True, **options)

eval_dataloader = CustomDataloader(last_batch=True,
                                   shuffle=False,
                                   fold='test',
                                   **options)

start_epoch = get_current_epoch(options['result_dir']) + 1
for epoch in range(start_epoch, start_epoch + options['epochs']):
    train_classifier(networks, optimizers, dataloader, epoch=epoch, **options)
    #print(networks['classifier_kplusone'])
    #weights = networks['classifier_kplusone'].fc1.weight
    eval_results = evaluate_with_comparison(networks, eval_dataloader,
                                            **options)
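Note: evaluate_with_comparison returns per-epoch metrics, but nothing in this excerpt persists them. A hedged sketch of one way to append each epoch's results to a JSON-lines log follows; the assumption that eval_results is a flat dict of scalars, and the eval_log.jsonl filename, are mine, not the project's.

import json
import os

def append_eval_results(eval_results, result_dir, epoch):
    """Append one epoch's metrics as a JSON line (assumes a dict of scalars)."""
    path = os.path.join(result_dir, 'eval_log.jsonl')
    with open(path, 'a') as f:
        f.write(json.dumps({'epoch': epoch, **eval_results}) + '\n')

# Hypothetical usage inside the epoch loop above:
# append_eval_results(eval_results, options['result_dir'], epoch)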
Example #3
import os
import time

import torch
import torch.nn.functional as F
from torch.autograd import Variable

from dataloader import FlexibleCustomDataloader


def train_classifier(networks, optimizers, dataloader, epoch=None, **options):
    for net in networks.values():
        net.train()
    netD = networks['discriminator']
    optimizerD = optimizers['discriminator']
    result_dir = options['result_dir']
    batch_size = options['batch_size']
    image_size = options['image_size']
    latent_size = options['latent_size']

    # Hack: use a ground-truth dataset to test
    #dataset_filename = '/mnt/data/svhn-59.dataset'
    dataset_filename = os.path.join(options['result_dir'], 'aux_dataset.dataset')
    aux_dataloader = FlexibleCustomDataloader(dataset_filename, batch_size=batch_size, image_size=image_size)

    start_time = time.time()
    correct = 0
    total = 0

    for i, (images, class_labels) in enumerate(dataloader):
        images = Variable(images).cuda()
        labels = Variable(class_labels).cuda()

        ############################
        # Discriminator Updates
        ###########################
        netD.zero_grad()

        # Classify real examples into the correct K classes
        real_logits = netD(images)
        positive_labels = (labels == 1).type(torch.cuda.FloatTensor)
        augmented_logits = F.pad(real_logits, pad=(0,1))
        augmented_labels = F.pad(positive_labels, pad=(0,1))
        log_likelihood = F.log_softmax(augmented_logits, dim=1) * augmented_labels
        errC = -0.5 * log_likelihood.mean()

        # Classify the user-labeled (active learning) examples
        aux_images, aux_labels = aux_dataloader.get_batch()
        aux_images = Variable(aux_images)
        aux_labels = Variable(aux_labels)
        aux_logits = netD(aux_images)
        augmented_logits = F.pad(aux_logits, pad=(0,1))
        augmented_labels = F.pad(aux_labels, pad=(0, 1))
        augmented_positive_labels = (augmented_labels == 1).type(torch.FloatTensor).cuda()
        is_positive = (aux_labels.max(dim=1)[0] == 1).type(torch.FloatTensor).cuda()
        is_negative = 1 - is_positive
        fake_log_likelihood = F.log_softmax(augmented_logits, dim=1)[:,-1] * is_negative
        #real_log_likelihood = augmented_logits[:,-1].abs() * is_positive
        real_log_likelihood = (F.log_softmax(augmented_logits, dim=1) * augmented_positive_labels).sum(dim=1)
        errC -= fake_log_likelihood.mean() 
        errC -= 0.5 * real_log_likelihood.mean()

        errC.backward()
        optimizerD.step()
        ############################

        # Keep track of accuracy on positive-labeled examples for monitoring
        _, pred_idx = real_logits.max(1)
        _, label_idx = labels.max(1)
        correct += (pred_idx == label_idx).sum().item()
        total += len(labels)

        if i % 100 == 0:
            bps = (i + 1) / (time.time() - start_time)
            ec = errC.item()  # .data[0] is the pre-0.4 API
            acc = correct / max(total, 1)
            msg = '[{}][{}/{}] C:{:.3f} Acc. {:.3f} {:.3f} batch/sec'
            print(msg.format(epoch, i + 1, len(dataloader), ec, acc, bps))
            print("Accuracy {}/{}".format(correct, total))
    return True
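Note: this variant trains a discriminator against one-hot labels, using row-level masks to route labeled examples toward their true class and unlabeled ones toward the open-set column. Here is that masking arithmetic in isolation, on toy tensors, with current PyTorch casts in place of the pre-0.4 type(torch.cuda.FloatTensor) style; the 0/1 label encoding is an assumption inferred from the code.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)                      # 4 examples, K = 3 classes
aux_labels = torch.tensor([[1., 0., 0.],        # labeled as class 0
                           [0., 0., 0.],        # unlabeled (open set)
                           [0., 1., 0.],        # labeled as class 1
                           [0., 0., 0.]])       # unlabeled (open set)

augmented_logits = F.pad(logits, (0, 1))        # add the open-set column
augmented_labels = F.pad(aux_labels, (0, 1))    # keep shapes aligned
positive_mask = (augmented_labels == 1).float()

is_positive = (aux_labels.max(dim=1)[0] == 1).float()  # row has a known label
is_negative = 1.0 - is_positive

log_probs = F.log_softmax(augmented_logits, dim=1)
fake_ll = log_probs[:, -1] * is_negative               # unlabeled -> open set
real_ll = (log_probs * positive_mask).sum(dim=1)       # labeled -> true class
errC = -(fake_ll.mean() + 0.5 * real_ll.mean())
print(errC.item())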
Example #4
import os

import torch
import torch.nn.functional as F
from torch.autograd import Variable

import losses  # project-local module providing the custom loss classes
from dataloader import FlexibleCustomDataloader
# `log` is a project-local logging helper assumed to be in scope.


def train_classifier(networks, optimizers, dataloader, epoch=None, **options):
    for net in networks.values():
        net.train()
    netC = networks['classifier_kplusone']
    optimizerC = optimizers['classifier_kplusone']
    batch_size = options['batch_size']
    image_size = options['image_size']

    dataset_filename = options.get('aux_dataset')
    if not dataset_filename or not os.path.exists(dataset_filename):
        raise ValueError("Aux Dataset not available")
    print("Using aux_dataset {}".format(dataset_filename))
    aux_dataloader = FlexibleCustomDataloader(dataset_filename,
                                              batch_size=batch_size,
                                              image_size=image_size)

    loss_class = losses.losses()

    for i, (images, class_labels) in enumerate(dataloader):
        images = Variable(images).cuda()
        # Following line FOR MNIST ONLY!!!!!!!! Remove otherwise
        #images = T.Pad(2).forward(images)
        labels = Variable(class_labels).cuda()

        ############################
        # Classifier Update
        ############################
        netC.zero_grad()

        # Classify real examples into the correct K classes
        #classifier_logits = netC(images)
        #augmented_logits = F.pad(classifier_logits, (0,1))
        #_, labels_idx = labels.max(dim=1)
        # TODO:: Replace with Matt's loss function ::
        #errC = F.nll_loss(F.log_softmax(augmented_logits, dim=1), labels_idx)
        #errC.backward()
        classifier_logits = netC(images)
        _, labels_idx = labels.max(dim=1)
        #errC = loss_class.kliep_loss(classifier_logits, labels_idx)
        errC = loss_class.power_loss_05(classifier_logits, labels_idx)
        errC.backward()

        log.collect('Classifier Loss', errC)

        # Classify aux_dataset examples as open set
        aux_images, aux_labels = aux_dataloader.get_batch()
        #classifier_logits = netC(Variable(aux_images))
        #augmented_logits = F.pad(classifier_logits, (0,1))
        #log_soft_open = F.log_softmax(augmented_logits, dim=1)[:, -1]
        #errOpenSet = -log_soft_open.mean()
        #errOpenSet.backward()
        classifier_logits = netC(Variable(aux_images))
        augmented_logits = F.pad(classifier_logits, (0, 1))
        target_label = Variable(torch.LongTensor(
            classifier_logits.shape[0])).cuda()
        # Every aux example targets the appended (K+1)-th open-set column.
        target_label[:] = classifier_logits.shape[1]
        #densityratio_loss = loss_class.kliep_loss(augmented_logits, target_label)
        densityratio_loss = loss_class.power_loss_05(augmented_logits,
                                                     target_label)
        densityratio_loss.backward()

        log.collect('Open Set Loss', densityratio_loss)

        optimizerC.step()
        ############################

        # Keep track of accuracy on positive-labeled examples for monitoring
        log.collect_prediction('Classifier Accuracy', netC(images), labels)

        log.print_every()

    return True
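Note: kliep_loss and power_loss_05 come from a project-local losses module (density-ratio style losses; KLIEP is Kullback-Leibler importance estimation). Their implementations are not shown here, so the following is a purely hypothetical stand-in that only reproduces the interface the loop relies on, (logits, target indices) -> scalar loss, implemented as plain cross-entropy.

import torch.nn.functional as F

class losses:
    """Hypothetical stand-in for the project's losses class: same call
    signature as kliep_loss / power_loss_05, but plain cross-entropy so
    the training loop above can run end to end."""

    def power_loss_05(self, logits, target_idx):
        return F.cross_entropy(logits, target_idx)

    kliep_loss = power_loss_05

# The real module is instantiated as losses.losses(); this stand-in as:
loss_class = losses()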
Example #5
File: main.py  Project: kth0522/RNCl
import argparse
import datetime
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
# Project-level imports (classifier32, FlexibleCustomDataloader, SummaryWriter,
# evalute_classifier, evaluate_openset) are omitted from this excerpt.


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help="batch size")
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--resume',
                        '-r',
                        action='store_true',
                        help='resume from checkpoint')
    parser.add_argument('--fold',
                        '-f',
                        type=int,
                        default=0,
                        help='which fold to train on')
    args = parser.parse_args()

    DATASET = 'tiny_imagenet-known-20-split'
    # MODEL = 'custom_classifier_9'
    MODEL = 'classifier32'
    fold_num = args.fold
    batch_size = args.batch_size
    is_train = True
    is_write = True

    start_time = datetime.datetime.now().strftime('%Y-%m-%d_%I-%M-%S-%p')
    runs = 'runs/{}-{}{}-{}'.format(MODEL, DATASET, fold_num, start_time)
    if is_write:
        writer = SummaryWriter(runs)

    closed_trainloader = FlexibleCustomDataloader(
        fold='train',
        batch_size=batch_size,
        dataset='./data/{}{}a.dataset'.format(DATASET, fold_num))
    closed_testloader = FlexibleCustomDataloader(
        fold='test',
        batch_size=batch_size,
        dataset='./data/{}{}a.dataset'.format(DATASET, fold_num))

    open_trainloader = FlexibleCustomDataloader(
        fold='train',
        batch_size=batch_size,
        dataset='./data/{}{}b.dataset'.format(DATASET, fold_num))
    open_testloader = FlexibleCustomDataloader(
        fold='test',
        batch_size=batch_size,
        dataset='./data/{}{}b.dataset'.format(DATASET, fold_num))

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    PATH = '{}/{}{}_custom_network_15'.format(runs, DATASET, fold_num)
    if is_train:
        net = classifier32()
        net.to(device)
        net.train()

        criterion = nn.CrossEntropyLoss()
        #optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
        optimizer = optim.Adam(net.parameters(), lr=0.0001)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[50, 100, 150, 200, 250, 300, 350, 400, 450],
            gamma=0.1)

        running_loss = 0.0
        for epoch in range(30):
            for i, (images, labels) in enumerate(closed_trainloader, 0):
                images = Variable(images).cuda()
                labels = Variable(labels).cuda()

                optimizer.zero_grad()

                # writer.add_graph(net, images)
                outputs = net(images)

                labels = torch.argmax(labels, dim=1)

                # writer.add_embedding(outputs, metadata=class_labels, label_img=images.unsqueeze(1))
                loss = criterion(outputs, labels)
                loss.backward()

                optimizer.step()
                #scheduler.step()

                running_loss += loss.item()
                if i % 100 == 99:
                    if is_write:
                        writer.add_scalar('training loss', running_loss / 100,
                                          epoch * len(closed_trainloader) + i)
                    current_time = datetime.datetime.now().strftime(
                        '%Y-%m-%d_%I-%M-%S-%p')
                    print(current_time)
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i + 1, running_loss / 100))

                    # writer.add_figure('predictions vs. actuals',
                    #                   plot_classes_preds(net, images, labels))
                    running_loss = 0.0
            if epoch % 50 == 49:
                torch.save(net.state_dict(),
                           "{}_{}.pth".format(PATH, epoch + 1))
            torch.save(net.state_dict(), "{}_latest.pth".format(PATH))

    test_net = classifier32()
    # PATH_1 = "/home/taehokim/PycharmProjects/RNCl/runs/custom_classifier_14-tiny_imagenet-known-20-split0-2020-08-26_08-25-01-AM"
    # PATH = '{}/{}{}_custom_classifier_13'.format(PATH_1, DATASET, fold_num)
    test_net.load_state_dict(torch.load("{}_latest.pth".format(PATH)))
    test_net.to(device)

    closed_acc = evalute_classifier(test_net, closed_testloader)
    print("closed-set accuracy: ", closed_acc)
    auc_d = evaluate_openset(test_net, closed_testloader, open_testloader)
    print("auc discriminator: ", auc_d)

    result_file = '{}/{}{}.txt'.format(runs, DATASET, fold_num)

    current_time = datetime.datetime.now().strftime('%Y-%m-%d_%I-%M-%S-%p')

    # Mode 'a' creates the file if it does not exist, so no branch is needed.
    with open(result_file, 'a') as f:
        f.write(current_time + "\n")
        f.write("{}{} \n".format(DATASET, fold_num))
        f.write("{} epoch\n".format(i))
        f.write("close-set accuracy: {} \n".format(closed_acc))
        f.write("AUROC: {} \n".format(auc_d))
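Note: evalute_classifier and evaluate_openset are project code. The standard baseline that evaluate_openset most likely resembles scores each test image by its maximum softmax probability and computes AUROC with closed-set images as positives and open-set images as negatives. A minimal sketch of that metric with sklearn follows; the scoring rule is an assumption, not necessarily this repo's exact method.

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

def max_softmax_scores(net, dataloader, device):
    """Score each example by its max softmax probability (higher = more 'known')."""
    net.eval()
    scores = []
    with torch.no_grad():
        for images, _ in dataloader:
            probs = F.softmax(net(images.to(device)), dim=1)
            scores.append(probs.max(dim=1)[0].cpu().numpy())
    return np.concatenate(scores)

def openset_auroc(net, closed_loader, open_loader, device):
    known = max_softmax_scores(net, closed_loader, device)
    unknown = max_softmax_scores(net, open_loader, device)
    y_true = np.concatenate([np.ones_like(known), np.zeros_like(unknown)])
    return roc_auc_score(y_true, np.concatenate([known, unknown]))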
Example #6
import argparse
import datetime
import gc
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
# Project-level imports (encoder32, classifier32, ResidualFlow, HybridModel,
# RunningAverageMeter, ExponentialMovingAverage, compute_loss, estimator_moments,
# update_lipschitz, evalute_classifier, evaluate_openset,
# FlexibleCustomDataloader, SummaryWriter) are omitted from this excerpt.


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help="batch size")
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--resume',
                        '-r',
                        action='store_true',
                        help='resume from checkpoint')
    parser.add_argument('--fold',
                        '-f',
                        type=int,
                        default=0,
                        help='which fold to train on')
    parser.add_argument('--seed', type=int, default=None)
    # argparse's type=bool treats any non-empty string as True; use a flag.
    parser.add_argument('--multi-eval', action='store_true')
    parser.add_argument('--update-freq', type=int, default=1)
    args = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    if args.seed is None:
        args.seed = np.random.randint(100000)

    print("seed: {}".format(args.seed))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed(args.seed)

    DATASET = 'tiny_imagenet-known-20-split'
    # MODEL = 'custom_classifier_9'
    MODEL = 'hybrid'
    fold_num = args.fold
    batch_size = args.batch_size
    is_train = False
    is_write = False

    start_time = datetime.datetime.now().strftime('%Y-%m-%d_%I-%M-%S-%p')
    runs = 'runs/{}-{}{}-{}'.format(MODEL, DATASET, fold_num, start_time)
    if is_write:
        writer = SummaryWriter(runs)

    closed_trainloader = FlexibleCustomDataloader(
        fold='train',
        batch_size=batch_size,
        dataset='./data/{}{}a.dataset'.format(DATASET, fold_num))
    closed_testloader = FlexibleCustomDataloader(
        fold='test',
        batch_size=batch_size,
        dataset='./data/{}{}a.dataset'.format(DATASET, fold_num))

    open_trainloader = FlexibleCustomDataloader(
        fold='train',
        batch_size=batch_size,
        dataset='./data/{}{}b.dataset'.format(DATASET, fold_num))
    open_testloader = FlexibleCustomDataloader(
        fold='test',
        batch_size=batch_size,
        dataset='./data/{}{}b.dataset'.format(DATASET, fold_num))

    batch_time = RunningAverageMeter(0.97)
    bpd_meter = RunningAverageMeter(0.97)
    logpz_meter = RunningAverageMeter(0.97)
    deltalogp_meter = RunningAverageMeter(0.97)
    firmom_meter = RunningAverageMeter(0.97)
    secmom_meter = RunningAverageMeter(0.97)
    gnorm_meter = RunningAverageMeter(0.97)
    ce_meter = RunningAverageMeter(0.97)

    PATH = '{}/{}{}_hybrid'.format(runs, DATASET, fold_num)
    if is_train:
        encoder = encoder32()
        encoder.to(device)
        encoder.train()

        flow = ResidualFlow(n_classes=20,
                            input_size=(64, 128, 4, 4),
                            n_blocks=[32, 32, 32],
                            intermediate_dim=512,
                            factor_out=False,
                            quadratic=False,
                            init_layer=None,
                            actnorm=True,
                            fc_actnorm=False,
                            dropout=0,
                            fc=False,
                            coeff=0.98,
                            vnorms='2222',
                            n_lipschitz_iters=None,
                            sn_atol=1e-3,
                            sn_rtol=1e-3,
                            n_power_series=None,
                            n_dist='poisson',
                            n_samples=1,
                            kernels='3-1-3',
                            activation_fn='swish',
                            fc_end=True,
                            n_exact_terms=2,
                            preact=True,
                            neumann_grad=True,
                            grad_in_forward=False,
                            first_resblock=True,
                            learn_p=False,
                            classification='hybrid',
                            classification_hdim=256,
                            block_type='resblock')
        flow.to(device)
        flow.train()

        classifier = classifier32()
        classifier.to(device)
        classifier.train()

        ema = ExponentialMovingAverage(flow)

        flow.train()

        criterion = nn.CrossEntropyLoss()
        # optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
        optimizer = optim.Adam(encoder.parameters(), lr=0.0001)
        optimizer_2 = optim.Adam(flow.parameters(), lr=0.0001)
        optimizer_3 = optim.SGD(classifier.parameters(), lr=0.1, momentum=0.9)
        # optimizer_3 = optim.Adam(classifier.parameters(), lr=0.0001)

        # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
        #                                                  milestones=[50, 100, 150, 200, 250, 300, 350, 400, 450],
        #                                                  gamma=0.1)
        beta = 1
        running_loss = 0.0
        running_bpd = 0.0
        running_cls = 0.0
        best_loss = 1000
        tau = 100000
        for epoch in range(600):
            for i, (images, labels) in enumerate(closed_trainloader, 0):
                global_itr = epoch * len(closed_trainloader) + i
                images = Variable(images).cuda()
                labels = Variable(labels).cuda()

                # writer.add_graph(net, images)
                outputs = encoder(images)

                bpd, logits, logpz, neg_delta_logp = compute_loss(outputs,
                                                                  flow,
                                                                  beta=beta)
                cls_outputs = classifier(outputs)

                labels = torch.argmax(labels, dim=1)
                cls_loss = criterion(cls_outputs, labels)

                firmom, secmom = estimator_moments(flow)

                bpd_meter.update(bpd.item())
                logpz_meter.update(logpz.item())
                deltalogp_meter.update(neg_delta_logp.item())
                firmom_meter.update(firmom)
                secmom_meter.update(secmom)

                loss = bpd + cls_loss
                #
                # loss.backward()
                #
                # labels = torch.argmax(labels, dim=1)
                #
                # # writer.add_embedding(outputs, metadata=class_labels, label_img=images.unsqueeze(1))
                # loss = criterion(outputs, labels)
                loss.backward()

                if global_itr % args.update_freq == args.update_freq - 1:
                    if args.update_freq > 1:
                        with torch.no_grad():
                            for p in flow.parameters():
                                if p.grad is not None:
                                    p.grad /= args.update_freq

                    grad_norm = torch.nn.utils.clip_grad.clip_grad_norm_(
                        flow.parameters(), 1.)

                    optimizer.step()
                    optimizer_2.step()
                    optimizer_3.step()

                    optimizer.zero_grad()
                    optimizer_2.zero_grad()
                    optimizer_3.zero_grad()

                    update_lipschitz(flow)
                    ema.apply()
                    gnorm_meter.update(grad_norm)

                running_bpd += bpd.item()
                running_cls += cls_loss.item()
                running_loss += loss.item()

                if i % 100 == 99:
                    if is_write:
                        writer.add_scalar('bits per dimension',
                                          running_bpd / 100, global_itr)
                        writer.add_scalar('classification loss',
                                          running_cls / 100, global_itr)
                        writer.add_scalar('total loss', running_loss / 100,
                                          global_itr)
                    current_time = datetime.datetime.now().strftime(
                        '%Y-%m-%d_%I-%M-%S-%p')
                    print(current_time)
                    print(
                        '[%d, %5d] bpd: %.3f, cls_loss: %.3f, total_loss: %.3f'
                        % (epoch + 1, i + 1, running_bpd / 100,
                           running_cls / 100, running_loss / 100))
                    if epoch > 1 and running_loss / 100 < best_loss:
                        best_loss = running_loss / 100
                        print("best loss updated! :", best_loss)
                        torch.save(
                            {
                                'state_dict': flow.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict(),
                                'args': args,
                                'ema': ema,
                            }, "{}_flow_best.pth".format(PATH))

                        torch.save(encoder.state_dict(),
                                   "{}_encoder_best.pth".format(PATH))
                        torch.save(classifier.state_dict(),
                                   "{}_classifier_best.pth".format(PATH))

                    # writer.add_figure('predictions vs. actuals',
                    #                   plot_classes_preds(net, images, labels))
                    running_loss = 0.0
                    running_bpd = 0.0
                    running_cls = 0.0

                del images
                torch.cuda.empty_cache()
                gc.collect()

            if epoch % 50 == 49:
                torch.save(
                    {
                        'state_dict': flow.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'args': args,
                        'ema': ema,
                    }, "{}_flow_{}.pth".format(PATH, epoch + 1))

                torch.save(encoder.state_dict(),
                           "{}_encoder_{}.pth".format(PATH, epoch + 1))
                torch.save(classifier.state_dict(),
                           "{}_classifier_{}.pth".format(PATH, epoch + 1))

    PATH_1 = "/home/taehokim/PycharmProjects/RNCl/runs/hybrid-tiny_imagenet-known-20-split0-2020-09-21_05-49-50-PM"
    PATH = "{}/{}{}_hybrid".format(PATH_1, DATASET, fold_num)

    if args.multi_eval:
        for i in range(50, 550, 50):
            test_encoder = encoder32()
            test_encoder.to(device)
            test_encoder.load_state_dict(
                torch.load("{}_encoder_{}.pth".format(PATH, i)))
            # state_dict = torch.load("{}_encoder_{}.pth".format(PATH, i))
            # # create new OrderedDict that does not contain `module.`
            #
            # new_state_dict = OrderedDict()
            # for k, v in state_dict.items():
            #     name = k[7:]  # remove `module.`
            #     new_state_dict[name] = v
            # # load params
            # test_encoder.load_state_dict(new_state_dict)

            test_classifier = classifier32()
            test_classifier.to(device)
            # state_dict = torch.load("{}_classifier_{}.pth".format(PATH, i))
            # # create new OrderedDict that does not contain `module.`
            #
            # new_state_dict = OrderedDict()
            # for k, v in state_dict.items():
            #     name = k[7:]  # remove `module.`
            #     new_state_dict[name] = v
            # # load params
            # test_classifier.load_state_dict(new_state_dict)
            test_classifier.load_state_dict(
                torch.load("{}_classifier_{}.pth".format(PATH, i)))

            test_flow = ResidualFlow(n_classes=20,
                                     input_size=(64, 128, 4, 4),
                                     n_blocks=[32, 32, 32],
                                     intermediate_dim=512,
                                     factor_out=False,
                                     quadratic=False,
                                     init_layer=None,
                                     actnorm=True,
                                     fc_actnorm=False,
                                     dropout=0,
                                     fc=False,
                                     coeff=0.98,
                                     vnorms='2222',
                                     n_lipschitz_iters=None,
                                     sn_atol=1e-3,
                                     sn_rtol=1e-3,
                                     n_power_series=None,
                                     n_dist='poisson',
                                     n_samples=1,
                                     kernels='3-1-3',
                                     activation_fn='swish',
                                     fc_end=True,
                                     n_exact_terms=2,
                                     preact=True,
                                     neumann_grad=True,
                                     grad_in_forward=False,
                                     first_resblock=True,
                                     learn_p=False,
                                     classification='hybrid',
                                     classification_hdim=256,
                                     block_type='resblock')

            test_flow.to(device)

            with torch.no_grad():
                # Dummy forward pass to initialize data-dependent layers;
                # shape matches the input_size=(64, 128, 4, 4) given above.
                x = torch.rand(1, 128, 4, 4).to(device)
                test_flow(x)
            checkpt = torch.load("{}_flow_{}.pth".format(PATH, i))
            sd = {
                k: v
                for k, v in checkpt['state_dict'].items()
                if 'last_n_samples' not in k
            }
            state = test_flow.state_dict()
            state.update(sd)
            test_flow.load_state_dict(state, strict=True)
            # test_ema.set(checkpt['ema'])

            hybrid = HybridModel(test_encoder, test_classifier, test_flow)

            closed_acc = evalute_classifier(hybrid, closed_testloader)
            print("closed-set accuracy: ", closed_acc)
            auc_d = evaluate_openset(hybrid, closed_testloader,
                                     open_testloader)
            print("auc discriminator: ", auc_d)

            result_file = '{}/{}{}.txt'.format(runs, DATASET, fold_num)

            current_time = datetime.datetime.now().strftime(
                '%Y-%m-%d_%I-%M-%S-%p')

            if is_write:
                # Mode 'a' creates the file if it does not exist.
                with open(result_file, 'a') as f:
                    f.write(current_time + "\n")
                    f.write("seed: {}\n".format(args.seed))
                    f.write("{}{} \n".format(DATASET, fold_num))
                    f.write("{} epoch\n".format(i))
                    f.write("close-set accuracy: {} \n".format(closed_acc))
                    f.write("AUROC: {} \n".format(auc_d))
    else:
        PATH_1 = "/home/taehokim/PycharmProjects/RNCl/runs/hybrid-tiny_imagenet-known-20-split0-2020-09-21_05-49-50-PM"
        PATH = "{}/{}{}_hybrid".format(PATH_1, DATASET, fold_num)

        test_encoder = encoder32()
        test_encoder.to(device)
        test_encoder.load_state_dict(
            torch.load("{}_encoder_latest.pth".format(PATH)))

        test_classifier = classifier32()
        test_classifier.to(device)
        test_classifier.load_state_dict(
            torch.load("{}_classifier_latest.pth".format(PATH)))

        test_flow = ResidualFlow(n_classes=20,
                                 input_size=(64, 128, 4, 4),
                                 n_blocks=[32, 32, 32],
                                 intermediate_dim=512,
                                 factor_out=False,
                                 quadratic=False,
                                 init_layer=None,
                                 actnorm=True,
                                 fc_actnorm=False,
                                 dropout=0,
                                 fc=False,
                                 coeff=0.98,
                                 vnorms='2222',
                                 n_lipschitz_iters=None,
                                 sn_atol=1e-3,
                                 sn_rtol=1e-3,
                                 n_power_series=None,
                                 n_dist='poisson',
                                 n_samples=1,
                                 kernels='3-1-3',
                                 activation_fn='swish',
                                 fc_end=True,
                                 n_exact_terms=2,
                                 preact=True,
                                 neumann_grad=True,
                                 grad_in_forward=False,
                                 first_resblock=True,
                                 learn_p=False,
                                 classification='hybrid',
                                 classification_hdim=256,
                                 block_type='resblock')

        test_flow.to(device)

        with torch.no_grad():
            # Dummy forward pass to initialize data-dependent layers;
            # shape matches the input_size=(64, 128, 4, 4) given above.
            x = torch.rand(1, 128, 4, 4).to(device)
            test_flow(x)
        checkpt = torch.load("{}_flow_latest.pth".format(PATH))
        sd = {
            k: v
            for k, v in checkpt['state_dict'].items()
            if 'last_n_samples' not in k
        }
        state = test_flow.state_dict()
        state.update(sd)
        test_flow.load_state_dict(state, strict=True)

        hybrid = HybridModel(test_encoder, test_classifier, test_flow)

        closed_acc = evalute_classifier(hybrid, closed_testloader)
        print("closed-set accuracy: ", closed_acc)
        auc_d = evaluate_openset(hybrid, closed_testloader, open_testloader)
        print("auc discriminator: ", auc_d)