Example #1
                    v for v in func.trainable_variables if 'bias' not in v.name
                ]
                l2_loss = tf.add_n(
                    [tf.reduce_sum(tf.math.square(v)) for v in weights]) * 1e-6
                loss = loss + l2_loss

            # Forward solve is done: read the accumulated number of function
            # evaluations (NFE) and reset the counter before backpropagation.
            nfe = func.nfe.numpy()
            func.nfe.assign(0.)
            grads = tape.gradient(loss, func.trainable_variables)
            # Evaluations accumulated during tape.gradient() give the backward count (NBE).
            nbe = func.nfe.numpy()
            func.nfe.assign(0.)
            print('NFE: {}, NBE: {}'.format(nfe, nbe))

            grad_vars = zip(grads, func.trainable_variables)
            optimizer.apply_gradients(grad_vars)
            time_meter.update(time.time() - end)
            loss_meter.update(loss.numpy())
            if itr % args.test_freq == 0:
                pred_x_extrap = odeint(func, x_val_extrap[0], t)
                pred_x_interp = odeint(func, x_val_interp[0], t)
                loss_extrap = tf.reduce_mean(
                    tf.abs(pred_x_extrap - x_val_extrap))
                loss_interp = tf.reduce_mean(
                    tf.abs(pred_x_interp - x_val_interp))
                print('Iter {:04d} | Traj. Loss ex.: {:.6f} | '
                      'Traj. Loss in.: {:.6f} | Seconds/batch {:,.4f}'.format(
                          itr, loss_extrap.numpy(), loss_interp.numpy(),
                          time_meter.avg))
                visualize(func,
                          np.array(x_val),
                          PLOT_DIR,
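Note: the excerpt above is cut off at both ends. As a point of reference, a self-contained sketch of the L2-penalty pattern its opening lines implement (summing squared non-bias trainable variables) could look like the following; the helper name l2_penalty and its standalone form are assumptions, only the bias-filtering rule and the 1e-6 coefficient come from the excerpt.

import tensorflow as tf

def l2_penalty(model, coeff=1e-6):
    # Collect every trainable variable except biases, as in the excerpt above.
    weights = [v for v in model.trainable_variables if 'bias' not in v.name]
    # Scaled sum of squared weights, to be added to the task loss inside the tape.
    return coeff * tf.add_n([tf.reduce_sum(tf.math.square(v)) for v in weights])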
Example #2
    def train(self, number_of_epochs):
        model = self.net
        device = self.device
        name = self.name
        writer = SummaryWriter(log_dir='experiments/' + str(name))
        makedirs(os.path.join(os.getcwd(), "experiments", name))

        model = model.float().to(device)
        criterion = self.criterion

        train_loader, test_loader = self.train_dataloader, self.test_dataloader
        data_gen = inf_generator(train_loader)
        batches_per_epoch = len(train_loader)

        lr_fn = self.scheduler

        optimizer = self.optimizer(model.parameters())

        best_leading_metric = 0.0
        batch_time_meter = RunningAverageMeter()
        loss_meter = RunningAverageMeter()
        if self.nfe_logging:
            f_nfe_meter = RunningAverageMeter()
            b_nfe_meter = RunningAverageMeter()
        end = time.time()

        for itr in tqdm(range(number_of_epochs * batches_per_epoch)):
            if lr_fn is not None:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_fn(itr)

            model.train()
            optimizer.zero_grad()
            dct = next(data_gen)
            # print(dct)
            model_input = [
                dct[key].float().to(device) for key in dct.keys()
                if "data" in key
            ]
            y = dct["label"].to(device)
            # x = x.unsqueeze(1)
            model_input = self.input_preprocessing(model_input)
            logits = model(*model_input)
            logits = self.loss_preprocessing(logits)
            loss = criterion(logits, y)

            if self.nfe_logging:
                nfe_forward = model.feature_layers[0].nfe
                model.feature_layers[0].nfe = 0

            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            if itr % 10 == 9:
                writer.add_scalar("Loss/train", loss_meter.val, itr)
            if self.nfe_logging:
                nfe_backward = model.feature_layers[0].nfe
                model.feature_layers[0].nfe = 0
                f_nfe_meter.update(nfe_forward)
                b_nfe_meter.update(nfe_backward)

            batch_time_meter.update(time.time() - end)
            end = time.time()

            if itr % batches_per_epoch == 0:
                with torch.no_grad():
                    model.eval()
                    preds = []
                    labs = []
                    for i, data in enumerate(test_loader, 0):
                        model_input = [
                            data[key].float().to(device)
                            for key in data.keys() if "data" in key
                        ]
                        labels = data['label'].to(device)
                        model_input = self.input_preprocessing(model_input)
                        outputs = model(*model_input)
                        predicted = self.output_to_pred_fcn(outputs)
                        # if len(predicted.tolist()) != len(labels.tolist()):
                        #     print(len(model_input[0].tolist()))
                        #     print(len(labels.tolist()))
                        preds += predicted.tolist()
                        labs += labels.tolist()
                    for metric in self.metrics.keys():
                        metric_val = self.metrics[metric](labs, preds)
                        if metric == self.leading_metric and metric_val > best_leading_metric:
                            best_leading_metric = metric_val
                            torch.save({'state_dict': model.state_dict()},
                                       os.path.join(os.getcwd(), "experiments",
                                                    name, 'model_best.pth'))
                        writer.add_scalar(metric + "/test", metric_val,
                                          itr // batches_per_epoch)
                    if self.nfe_logging:
                        writer.add_scalar("NFE-F", f_nfe_meter.val,
                                          itr // batches_per_epoch)
                        writer.add_scalar("NFE-B", b_nfe_meter.val,
                                          itr // batches_per_epoch)

        labs = []
        preds = []
        for data in test_loader:
            with torch.no_grad():
                model.eval()
                model_input = [
                    data[key].float().to(device) for key in data.keys()
                    if "data" in key
                ]
                labels = data['label'].to(device)
                model_input = self.input_preprocessing(model_input)
                outputs = model(*model_input)
                predicted = self.output_to_pred_fcn(outputs)
                preds += predicted.tolist()
                labs += labels.tolist()

        torch.save({'state_dict': model.state_dict()},
                   os.path.join(os.getcwd(), "experiments", name,
                                'model_last.pth'))
        try:
            labs = [self.class_dict[a] for a in labs]
            preds = [self.class_dict[a] for a in preds]
            other_metrics = [[metric, [self.metrics[metric](labs, preds)]]
                             for metric in self.metrics
                             if metric != self.leading_metric]
            writer.add_figure(
                name + " - Confusion Matrix",
                plot_confusion_matrix(
                    labs, preds,
                    [self.class_dict[key] for key in self.class_dict.keys()]))
        except Exception:
            # Metric mapping / confusion-matrix plotting is best effort;
            # fall back to reporting no extra metrics.
            other_metrics = []
        writer.close()
        return [self.leading_metric, [best_leading_metric]], other_metrics
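Note: several of these examples obtain batches with data_gen = inf_generator(train_loader) and then call next() on it. A minimal sketch of such a helper, assuming it behaves like the utility commonly seen in torchdiffeq-style example code (restart the loader whenever it is exhausted):

def inf_generator(iterable):
    """Yield items from `iterable` forever, restarting it once it is exhausted."""
    iterator = iter(iterable)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(iterable)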
Example #3
    val_X = from_numpy(val_X).float().to(device)
    val_T = from_numpy(val_T).float().to(device)

    for itr in range(1, n_iters + 1):
        # training
        optimizer.zero_grad()
        loss = 0
        for batch_idx in range(training_T.shape[0]):
            y_hat, _, _, _ = latentODE(training_X[batch_idx, :, :seq_len, :],
                                       training_T[batch_idx, :])
            y = training_X[batch_idx, :, :, 3:4]
            loss += metric(y_hat, y)

        loss.backward()
        optimizer.step()
        loss_meter.update(loss.item() / training_T.shape[0])
        loss_meter.print()

        # validation
        with no_grad():
            # calculate the validation loss
            if itr % 10 == 0:
                val_loss = 0
                for val_idx in range(val_T.shape[0]):
                    y_hat, _, _, _ = latentODE(val_X[val_idx, :, :seq_len, :],
                                               val_T[val_idx, :])
                    y = val_X[val_idx, :, :, 3:4]
                    val_loss += metric(y_hat, y)
                val_loss_meter.update(val_loss.item() / val_T.shape[0])
                val_loss_meter.print()
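Note: the meters used throughout these examples expose .val, .avg and .update(). A compatible sketch, assuming exponential-moving-average semantics as in the torchdiffeq example utilities (the .print() method is an assumption added only to match the calls above):

class RunningAverageMeter:
    """Track the latest value and an exponential moving average of it."""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.val = None
        self.avg = 0.

    def update(self, val):
        # The first update seeds the average; later updates blend in with momentum.
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1. - self.momentum)
        self.val = val

    def print(self):
        print('val: {:.6f} | avg: {:.6f}'.format(self.val, self.avg))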
def trainer(model,
            logger,
            loader,
            args,
            data="mnist",
            optimizer=None,
            scheduler=None,
            adv_train=None,
            tboard=True,
            **kwargs):
    logger.info("=" * 80)
    logger.info("Train Info")
    logger.info("Model : {}".format(args.model))
    logger.info("Number of blocks : {}".format(args.block))
    logger.info("Number of parameters : {}".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))

    start_time = time.time()
    best_acc = 0.
    device = args.device

    try:
        criterion = model.loss().to(args.device)
    except Exception:
        # Fall back to NLL loss when the model does not define its own criterion.
        criterion = nn.NLLLoss().to(args.device)
    logger.info("Criterion : {}".format(criterion.__class__.__name__))
    logger.info("Adversarial Training : {}".format(adv_train))
    logger.info("=" * 80)
    data_gen = inf_generator(loader['train_loader'])
    batches_per_epoch = len(loader['train_loader'])

    writer = SummaryWriter(log_dir=os.path.join(
        args.save, args.model + "_" + str(args.block) + "_" + str(adv_train)))
    if data == "mnist":
        dummy_input = torch.rand(1, 1, 28, 28).to(device)
    else:
        dummy_input = torch.rand(1, 3, 32, 32).to(device)
    writer.add_graph(model, dummy_input)

    best_acc = 0.
    best_loss = 1000.
    best_acc_epoch = 0
    best_loss_epoch = 0
    batch_time_meter = RunningAverageMeter()
    end_time = time.time()
    if args.hist:
        hist_dict = dict()
    args.alpha /= 255
    adv, stats = adv_train_module(adv_train, model, data, args.iters,
                                  args.device, args.alpha, args.repeat)

    torch.save({
        "state_dict": model.state_dict(),
        "args": args
    }, os.path.join(args.save, "model_acc.pt"))
    torch.save({
        "state_dict": model.state_dict(),
        "args": args
    }, os.path.join(args.save, "model_loss.pt"))
    for itr in tqdm.tqdm(range(args.epochs * batches_per_epoch)):

        if itr % batches_per_epoch == 0 and scheduler is not None:
            scheduler.step()

        model.train()
        optimizer.zero_grad()
        x, y = next(data_gen)
        x = x.to(args.device)
        y = y.to(args.device)

        if adv_train is not None:
            x = adv.perturb(x, y, device=args.device)
            if adv_train == "ball":
                y = torch.cat([y for _ in range(stats[0])])
        model.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)

        loss.backward()
        optimizer.step()

        batch_time_meter.update(time.time() - end_time)
        end_time = time.time()
        writer.add_scalar("train_loss", loss.cpu().detach(), itr)

        if itr % batches_per_epoch == 0:
            image = adv.inverse_normalize(x.cpu())
            image = torchvision.utils.make_grid(image, scale_each=False)
            writer.add_image("train_image", image,
                             int(itr // batches_per_epoch))
            model.eval()
            with torch.no_grad():
                train_acc, train_loss = accuracy(
                    model,
                    dataset_loader=loader['train_eval_loader'],
                    device=args.device,
                    criterion=criterion)
                val_acc, val_loss = accuracy(
                    model,
                    dataset_loader=loader['test_loader'],
                    device=args.device,
                    criterion=criterion)
                if val_acc >= best_acc:
                    torch.save({
                        "state_dict": model.state_dict(),
                        "args": args
                    }, os.path.join(args.save, "model_acc.pt"))
                    best_acc = val_acc
                    best_acc_epoch = int(itr // batches_per_epoch)
                if val_loss <= best_loss:
                    torch.save({
                        "state_dict": model.state_dict(),
                        "args": args
                    }, os.path.join(args.save, "model_loss.pt"))
                    best_loss = val_loss
                    best_loss_epoch = int(itr // batches_per_epoch)
                writer.add_scalar("train_loss_epoch", train_loss,
                                  int(itr // batches_per_epoch))
                writer.add_scalar("train_acc", train_acc,
                                  int(itr // batches_per_epoch))
                writer.add_scalar("validation_loss_epoch", val_loss,
                                  int(itr // batches_per_epoch))
                writer.add_scalar("validation_acc", val_acc,
                                  int(itr // batches_per_epoch))
                logger.info(
                    "Epoch {:03d} | Time {:.3f} ({:.3f}) | Train loss {:.4f} | Validation loss {:.4f} | Train Acc {:.4f} | Validation Acc {:.4f}"
                    .format(int(itr // batches_per_epoch),
                            batch_time_meter.val, batch_time_meter.avg,
                            train_loss, val_loss, train_acc, val_acc))
            torch.save({
                "state_dict": model.state_dict(),
                "args": args
            }, os.path.join(args.save, "model_final.pt"))

    torch.save({
        "state_dict": model.state_dict(),
        "args": args
    }, os.path.join(args.save, "model_final.pt"))
    if args.hist:
        with open(os.path.join(args.save, "history.json"), "w") as f:
            json.dump(hist_dict, f)

    logger.info("=" * 80)
    logger.info("Required Time : {:03d} minute {:.2f} seconds".format(
        int((time.time() - start_time) // 60),
        (time.time() - start_time) % 60))
    logger.info("Best Acc Epoch : {:03d}".format(best_acc_epoch))
    logger.info("Best Validation Accuracy : {:.4f}".format(best_acc))
    logger.info("Best loss Epoch : {:03d}".format(best_loss_epoch))
    logger.info("Best Validation loss : {:.4f}".format(best_loss))
    logger.info("Train end")
    logger.info("=" * 80)
    writer.close()

    return model
Example #5
    def train(self, number_of_epochs):
        model = self.encoder_net
        device = self.device
        name = self.name
        decoder = self.decoder_net
        writer = SummaryWriter(log_dir='experiments/' + str(name))
        makedirs(os.path.join(os.getcwd(), "experiments", name))

        model = model.float().to(device)
        decoder = decoder.float().to(device)
        criterion = self.criterion

        train_loader = self.train_dataloader
        data_gen = inf_generator(train_loader)
        batches_per_epoch = len(train_loader)

        lr_fn = self.scheduler
        params = list(model.parameters()) + list(decoder.parameters())
        optimizer = self.optimizer(params)

        batch_time_meter = RunningAverageMeter()
        loss_meter = RunningAverageMeter()
        end = time.time()

        for itr in tqdm(range(number_of_epochs * batches_per_epoch)):
            # print("start: {}".format(time.time() - end))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_fn(itr)
            # print("params_set: {}".format(time.time() - end))
            model.train()
            decoder.train()
            optimizer.zero_grad()
            # print("models_working: {}".format(time.time() - end))
            dct = next(data_gen)
            model_input = [
                torch.reshape(dct[key].float(), (-1, 180)).to(device)
                for key in dct.keys() if "data" in key
            ]
            # print("data_generated: {}".format(time.time() - end))
            embedding = model(*model_input)
            logits = decoder(embedding)
            logits = self.loss_preprocessing(logits) * 100
            model_input = torch.stack(model_input, dim=-1) * 100
            # print("predictions_made: {}".format(time.time() - end))
            loss = criterion(logits, model_input)
            # print("loss_counted: {}".format(time.time() - end))
            loss.backward()
            optimizer.step()
            # print("optimizer_step: {}".format(time.time() - end))
            loss_meter.update(loss.item())
            if itr % 10 == 0:
                writer.add_scalar("Loss/train", loss_meter.val, itr)
                torch.save({'state_dict': model.state_dict()},
                           os.path.join(os.getcwd(), "experiments", name,
                                        'model_final.pth'))
            batch_time_meter.update(time.time() - end)
            writer.add_scalar("batch_time/train", batch_time_meter.val, itr)
            end = time.time()

        writer.close()
        return ["Loss", loss_meter.val], [None]
Example #6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help="batch size")
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--resume',
                        '-r',
                        action='store_true',
                        help='resume from checkpoint')
    parser.add_argument('--fold',
                        '-f',
                        type=int,
                        default=0,
                        help='which fold to train with')
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--multi-eval',
                        action='store_true',
                        help='evaluate checkpoints saved at multiple epochs')
    parser.add_argument('--update-freq', type=int, default=1)
    args = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    if args.seed is None:
        args.seed = np.random.randint(100000)

    print("seed: {}".format(args.seed))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed(args.seed)

    DATASET = 'tiny_imagenet-known-20-split'
    # MODEL = 'custom_classifier_9'
    MODEL = 'hybrid'
    fold_num = args.fold
    batch_size = args.batch_size
    is_train = False
    is_write = False

    start_time = datetime.datetime.now().strftime('%Y-%m-%d_%I-%M-%S-%p')
    runs = 'runs/{}-{}{}-{}'.format(MODEL, DATASET, fold_num, start_time)
    if is_write:
        writer = SummaryWriter(runs)

    closed_trainloader = FlexibleCustomDataloader(
        fold='train',
        batch_size=batch_size,
        dataset='./data/{}{}a.dataset'.format(DATASET, fold_num))
    closed_testloader = FlexibleCustomDataloader(
        fold='test',
        batch_size=batch_size,
        dataset='./data/{}{}a.dataset'.format(DATASET, fold_num))

    open_trainloader = FlexibleCustomDataloader(
        fold='train',
        batch_size=batch_size,
        dataset='./data/{}{}b.dataset'.format(DATASET, fold_num))
    open_testloader = FlexibleCustomDataloader(
        fold='test',
        batch_size=batch_size,
        dataset='./data/{}{}b.dataset'.format(DATASET, fold_num))

    batch_time = RunningAverageMeter(0.97)
    bpd_meter = RunningAverageMeter(0.97)
    logpz_meter = RunningAverageMeter(0.97)
    deltalogp_meter = RunningAverageMeter(0.97)
    firmom_meter = RunningAverageMeter(0.97)
    secmom_meter = RunningAverageMeter(0.97)
    gnorm_meter = RunningAverageMeter(0.97)
    ce_meter = RunningAverageMeter(0.97)

    PATH = '{}/{}{}_hybrid'.format(runs, DATASET, fold_num)
    if is_train:
        encoder = encoder32()
        encoder.to(device)
        encoder.train()

        flow = ResidualFlow(n_classes=20,
                            input_size=(64, 128, 4, 4),
                            n_blocks=[32, 32, 32],
                            intermediate_dim=512,
                            factor_out=False,
                            quadratic=False,
                            init_layer=None,
                            actnorm=True,
                            fc_actnorm=False,
                            dropout=0,
                            fc=False,
                            coeff=0.98,
                            vnorms='2222',
                            n_lipschitz_iters=None,
                            sn_atol=1e-3,
                            sn_rtol=1e-3,
                            n_power_series=None,
                            n_dist='poisson',
                            n_samples=1,
                            kernels='3-1-3',
                            activation_fn='swish',
                            fc_end=True,
                            n_exact_terms=2,
                            preact=True,
                            neumann_grad=True,
                            grad_in_forward=False,
                            first_resblock=True,
                            learn_p=False,
                            classification='hybrid',
                            classification_hdim=256,
                            block_type='resblock')
        flow.to(device)
        flow.train()

        classifier = classifier32()
        classifier.to(device)
        classifier.train()

        ema = ExponentialMovingAverage(flow)

        flow.train()

        criterion = nn.CrossEntropyLoss()
        # optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
        optimizer = optim.Adam(encoder.parameters(), lr=0.0001)
        optimizer_2 = optim.Adam(flow.parameters(), lr=0.0001)
        optimizer_3 = optim.SGD(classifier.parameters(), lr=0.1, momentum=0.9)
        # optimizer_3 = optim.Adam(classifier.parameters(), lr=0.0001)

        # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
        #                                                  milestones=[50, 100, 150, 200, 250, 300, 350, 400, 450],
        #                                                  gamma=0.1)
        beta = 1
        running_loss = 0.0
        running_bpd = 0.0
        running_cls = 0.0
        best_loss = 1000
        tau = 100000
        for epoch in range(600):
            for i, (images, labels) in enumerate(closed_trainloader, 0):
                global_itr = epoch * len(closed_trainloader) + i
                images = images.to(device)
                # Move labels to the same device; the classification loss below
                # expects inputs and targets on the same device.
                labels = labels.to(device)

                # writer.add_graph(net, images)
                outputs = encoder(images)

                bpd, logits, logpz, neg_delta_logp = compute_loss(outputs,
                                                                  flow,
                                                                  beta=beta)
                cls_outputs = classifier(outputs)

                labels = torch.argmax(labels, dim=1)
                cls_loss = criterion(cls_outputs, labels)

                firmom, secmom = estimator_moments(flow)

                bpd_meter.update(bpd.item())
                logpz_meter.update(logpz.item())
                deltalogp_meter.update(neg_delta_logp.item())
                firmom_meter.update(firmom)
                secmom_meter.update(secmom)

                loss = bpd + cls_loss
                #
                # loss.backward()
                #
                # labels = torch.argmax(labels, dim=1)
                #
                # # writer.add_embedding(outputs, metadata=class_labels, label_img=images.unsqueeze(1))
                # loss = criterion(outputs, labels)
                loss.backward()

                if global_itr % args.update_freq == args.update_freq - 1:
                    if args.update_freq > 1:
                        with torch.no_grad():
                            for p in flow.parameters():
                                if p.grad is not None:
                                    p.grad /= args.update_freq

                    grad_norm = torch.nn.utils.clip_grad.clip_grad_norm_(
                        flow.parameters(), 1.)

                    optimizer.step()
                    optimizer_2.step()
                    optimizer_3.step()

                    optimizer.zero_grad()
                    optimizer_2.zero_grad()
                    optimizer_3.zero_grad()

                    update_lipschitz(flow)
                    ema.apply()
                    gnorm_meter.update(grad_norm)

                running_bpd += bpd.item()
                running_cls += cls_loss.item()
                running_loss += loss.item()

                if i % 100 == 99:
                    if is_write:
                        writer.add_scalar('bits per dimension',
                                          running_bpd / 100, global_itr)
                        writer.add_scalar('classification loss',
                                          running_cls / 100, global_itr)
                        writer.add_scalar('total loss', running_loss / 100,
                                          global_itr)
                    current_time = datetime.datetime.now().strftime(
                        '%Y-%m-%d_%I-%M-%S-%p')
                    print(current_time)
                    print(
                        '[%d, %5d] bpd: %.3f, cls_loss: %.3f, total_loss: %.3f'
                        % (epoch + 1, i + 1, running_bpd / 100,
                           running_cls / 100, running_loss / 100))
                    if epoch > 1 and running_loss / 100 < best_loss:
                        best_loss = running_loss / 100
                        print("best loss updated! :", best_loss)
                        torch.save(
                            {
                                'state_dict': flow.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict(),
                                'args': args,
                                'ema': ema,
                            }, "{}_flow_best.pth".format(PATH))

                        torch.save(encoder.state_dict(),
                                   "{}_encoder_best.pth".format(PATH))
                        torch.save(classifier.state_dict(),
                                   "{}_classifier_best.pth".format(PATH))

                    # writer.add_figure('predictions vs. actuals',
                    #                   plot_classes_preds(net, images, labels))
                    running_loss = 0.0
                    running_bpd = 0.0
                    running_cls = 0.0

                del images
                torch.cuda.empty_cache()
                gc.collect()

            if epoch % 50 == 49:
                torch.save(
                    {
                        'state_dict': flow.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'args': args,
                        'ema': ema,
                    }, "{}_flow_{}.pth".format(PATH, epoch + 1))

                torch.save(encoder.state_dict(),
                           "{}_encoder_{}.pth".format(PATH, epoch + 1))
                torch.save(classifier.state_dict(),
                           "{}_classifier_{}.pth".format(PATH, epoch + 1))

    PATH_1 = "/home/taehokim/PycharmProjects/RNCl/runs/hybrid-tiny_imagenet-known-20-split0-2020-09-21_05-49-50-PM"
    PATH = "{}/{}{}_hybrid".format(PATH_1, DATASET, fold_num)

    if args.multi_eval:
        for i in range(50, 550, 50):
            test_encoder = encoder32()
            test_encoder.to(device)
            test_encoder.load_state_dict(
                torch.load("{}_encoder_{}.pth".format(PATH, i)))
            # state_dict = torch.load("{}_encoder_{}.pth".format(PATH, i))
            # # create new OrderedDict that does not contain `module.`
            #
            # new_state_dict = OrderedDict()
            # for k, v in state_dict.items():
            #     name = k[7:]  # remove `module.`
            #     new_state_dict[name] = v
            # # load params
            # test_encoder.load_state_dict(new_state_dict)

            test_classifier = classifier32()
            test_classifier.to(device)
            # state_dict = torch.load("{}_classifier_{}.pth".format(PATH, i))
            # # create new OrderedDict that does not contain `module.`
            #
            # new_state_dict = OrderedDict()
            # for k, v in state_dict.items():
            #     name = k[7:]  # remove `module.`
            #     new_state_dict[name] = v
            # # load params
            # test_classifier.load_state_dict(new_state_dict)
            test_classifier.load_state_dict(
                torch.load("{}_classifier_{}.pth".format(PATH, i)))

            test_flow = ResidualFlow(n_classes=20,
                                     input_size=(64, 128, 4, 4),
                                     n_blocks=[32, 32, 32],
                                     intermediate_dim=512,
                                     factor_out=False,
                                     quadratic=False,
                                     init_layer=None,
                                     actnorm=True,
                                     fc_actnorm=False,
                                     dropout=0,
                                     fc=False,
                                     coeff=0.98,
                                     vnorms='2222',
                                     n_lipschitz_iters=None,
                                     sn_atol=1e-3,
                                     sn_rtol=1e-3,
                                     n_power_series=None,
                                     n_dist='poisson',
                                     n_samples=1,
                                     kernels='3-1-3',
                                     activation_fn='swish',
                                     fc_end=True,
                                     n_exact_terms=2,
                                     preact=True,
                                     neumann_grad=True,
                                     grad_in_forward=False,
                                     first_resblock=True,
                                     learn_p=False,
                                     classification='hybrid',
                                     classification_hdim=256,
                                     block_type='resblock')

            test_flow.to(device)

            with torch.no_grad():
                # Dummy forward pass so data-dependent layers (e.g. ActNorm) initialize.
                input_size = (64, 128, 4, 4)  # matches the ResidualFlow input_size above
                x = torch.rand(1, *input_size[1:]).to(device)
                test_flow(x)
            checkpt = torch.load("{}_flow_{}.pth".format(PATH, i))
            sd = {
                k: v
                for k, v in checkpt['state_dict'].items()
                if 'last_n_samples' not in k
            }
            state = test_flow.state_dict()
            state.update(sd)
            test_flow.load_state_dict(state, strict=True)
            # test_ema.set(checkpt['ema'])

            hybrid = HybridModel(test_encoder, test_classifier, test_flow)

            closed_acc = evalute_classifier(hybrid, closed_testloader)
            print("closed-set accuracy: ", closed_acc)
            auc_d = evaluate_openset(hybrid, closed_testloader,
                                     open_testloader)
            print("auc discriminator: ", auc_d)

            result_file = '{}/{}{}.txt'.format(runs, DATASET, fold_num)

            current_time = datetime.datetime.now().strftime(
                '%Y-%m-%d_%I-%M-%S-%p')

            if is_write:
                # Mode 'a' appends and also creates the file if it does not exist,
                # so one branch covers both the new-file and existing-file cases.
                with open(result_file, 'a') as f:
                    f.write(current_time + "\n")
                    f.write("seed: {}\n".format(args.seed))
                    f.write("{}{} \n".format(DATASET, fold_num))
                    f.write("{} epoch\n".format(i))
                    f.write("close-set accuracy: {} \n".format(closed_acc))
                    f.write("AUROC: {} \n".format(auc_d))
    else:
        PATH_1 = "/home/taehokim/PycharmProjects/RNCl/runs/hybrid-tiny_imagenet-known-20-split0-2020-09-21_05-49-50-PM"
        PATH = "{}/{}{}_hybrid".format(PATH_1, DATASET, fold_num)

        test_encoder = encoder32()
        test_encoder.to(device)
        test_encoder.load_state_dict(
            torch.load("{}_encoder_latest.pth".format(PATH)))

        test_classifier = classifier32()
        test_classifier.to(device)
        test_classifier.load_state_dict(
            torch.load("{}_classifier_latest.pth".format(PATH)))

        test_flow = ResidualFlow(n_classes=20,
                                 input_size=(64, 128, 4, 4),
                                 n_blocks=[32, 32, 32],
                                 intermediate_dim=512,
                                 factor_out=False,
                                 quadratic=False,
                                 init_layer=None,
                                 actnorm=True,
                                 fc_actnorm=False,
                                 dropout=0,
                                 fc=False,
                                 coeff=0.98,
                                 vnorms='2222',
                                 n_lipschitz_iters=None,
                                 sn_atol=1e-3,
                                 sn_rtol=1e-3,
                                 n_power_series=None,
                                 n_dist='poisson',
                                 n_samples=1,
                                 kernels='3-1-3',
                                 activation_fn='swish',
                                 fc_end=True,
                                 n_exact_terms=2,
                                 preact=True,
                                 neumann_grad=True,
                                 grad_in_forward=False,
                                 first_resblock=True,
                                 learn_p=False,
                                 classification='hybrid',
                                 classification_hdim=256,
                                 block_type='resblock')

        test_flow.to(device)

        with torch.no_grad():
            # Dummy forward pass so data-dependent layers (e.g. ActNorm) initialize.
            input_size = (64, 128, 4, 4)  # matches the ResidualFlow input_size above
            x = torch.rand(1, *input_size[1:]).to(device)
            test_flow(x)
        checkpt = torch.load("{}_flow_latest.pth".format(PATH))
        sd = {
            k: v
            for k, v in checkpt['state_dict'].items()
            if 'last_n_samples' not in k
        }
        state = test_flow.state_dict()
        state.update(sd)
        test_flow.load_state_dict(state, strict=True)

        hybrid = HybridModel(test_encoder, test_classifier, test_flow)

        closed_acc = evalute_classifier(hybrid, closed_testloader)
        print("closed-set accuracy: ", closed_acc)
        auc_d = evaluate_openset(hybrid, closed_testloader, open_testloader)
        print("auc discriminator: ", auc_d)
Example #7
class TrainingLoop:
    def __init__(self,
                 model,
                 train_loader,
                 val_loader,
                 plot_func=None,
                 loss_meters=None,
                 loss_hists=None):
        """Initialize main training loop for Neural ODE model.

        Dataloaders should return tuple of data and timepoints with shapes:
            ((B x L x D), (B x L)) where
            B = Batch size, L = Number of observations, D = Data dimension.

        Args:
            model (nn.Module): Model to train.
            train_loader (torch.utils.data.Dataloader): Training data loader.
            val_loader (torch.utils.data.Dataloader): Validation data loader.
            plot_func (function): Function used to plot predictions.
            loss_meters (RunningAverageMeter, RunningAverageMeter):
                Existing training / val loss average meters.
            loss_hists (list, list): Train/val loss history arrays.
        """

        self.model = model
        self.init_loss_history(loss_hists)
        self.init_loss_meter(loss_meters)

        self.train_loader = train_loader
        self.val_loader = val_loader

        self.plot_func = plot_func
        self.execution_arg_history = []

        self.runtimes = []

    def init_loss_history(self, loss_hist):
        if loss_hist is None:
            self.train_loss_hist = []
            self.val_loss_hist = []
        else:
            self.train_loss_hist = loss_hist[0]
            self.val_loss_hist = loss_hist[1]

    def init_loss_meter(self, loss_meters):
        if loss_meters is None:
            self.train_loss_meter = RunningAverageMeter()
            self.val_loss_meter = RunningAverageMeter()
        else:
            self.train_loss_meter = loss_meters[0]
            self.val_loss_meter = loss_meters[1]

    def save_checkpoint(self,
                        optim,
                        scheduler,
                        epoch,
                        epoch_times,
                        ckpt_path=None):
        if ckpt_path is None:
            ckpt_path = get_checkpoint_path()

        scheduler_sd = scheduler.state_dict() if scheduler else None

        torch.save(
            {
                'model_state_dict': self.model.state_dict(),
                'optim_state_dict': optim.state_dict(),
                'scheduler_state_dict': scheduler_sd,
                'epoch': epoch,
                'epoch_times': epoch_times,
                'loss_hists': [self.train_loss_hist, self.val_loss_hist],
                'loss_meters': [self.train_loss_meter, self.val_loss_meter],
            }, ckpt_path)

    def load_checkpoint(self, ckpt_path=None):
        if ckpt_path is None:
            ckpt_path = get_checkpoint_path()

        ckpt = torch.load(ckpt_path)

        self.model.load_state_dict(ckpt['model_state_dict'])
        self.init_loss_history(ckpt['loss_hists'])
        self.init_loss_meter(ckpt['loss_meters'])

        epoch_times = ckpt['epoch_times']

        epoch = ckpt['epoch']
        optim_sd = ckpt['optim_state_dict']
        scheduler_sd = ckpt['scheduler_state_dict']

        return epoch, epoch_times, optim_sd, scheduler_sd

    def train(self,
              optimizer,
              args,
              scheduler=None,
              verbose=True,
              plt_traj=False,
              plt_loss=False):
        """Execute main training loop for Neural ODE model.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer.
            args (dict): Additional training arguments. See below.
            scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
            verbose (bool): Print verbose training information.
            plt_traj (bool): Plot reconstructions.
            plt_loss (bool): Plot loss history.

        Additional Args:
            args['max_epoch'] (int): Maximum training epochs.
            args['l_std'] (float): Std used to calculate likelihood in ELBO.
            args['clip_norm'] (float): Max norm used to clip gradients.
            args['model_atol'] (float): Absolute tolerance used by ODE solve.
            args['model_rtol'] (float): Relative tolerance used by ODE solve.
            args['method'] (str): ODE solver used for ODE solve.
            args['plt_args'] (dict): Plotting arguments.
            args['ckpt_int'] (int): Number of epochs between checkpoints.
        """
        self.execution_arg_history.append(args)

        start_epoch = 1
        epoch_times = []

        if 'ckpt_int' in args and exists_checkpoint():
            ckpt = self.load_checkpoint()

            start_epoch = ckpt[0]
            epoch_times = ckpt[1]

            optimizer.load_state_dict(ckpt[2])
            if scheduler:
                scheduler.load_state_dict(ckpt[3])

        for epoch in range(start_epoch, args['max_epoch'] + 1):
            start_time = time.time()

            for b_data, b_time in self.train_loader:
                optimizer.zero_grad()

                out = self.model.forward(b_data, b_time[0], args)
                elbo = self.model.get_elbo(b_data, *out, args)

                self.train_loss_meter.update(elbo.item())

                elbo.backward()
                if 'clip_norm' in args:
                    clip_grad_norm_(self.model.parameters(), args['clip_norm'])
                optimizer.step()
            if scheduler:
                scheduler.step()

            end_time = time.time()
            epoch_times.append(end_time - start_time)

            with torch.no_grad():
                self.update_val_loss(args)

                if self.plot_func and plt_traj:
                    self.plot_val_traj(args['plt_args'])
                self.train_loss_hist.append(self.train_loss_meter.avg)
                self.val_loss_hist.append(self.val_loss_meter.val)

            if verbose:
                if scheduler:
                    print("Current LR: {}".format(scheduler.get_last_lr()),
                          flush=True)
                if plt_loss:
                    self.plot_loss()
                self.print_loss(epoch)

            if 'ckpt_int' in args and epoch % args['ckpt_int'] == 0:
                self.save_checkpoint(optimizer, scheduler, epoch, epoch_times)

        self.runtimes.append(epoch_times)

    def update_val_loss(self, args):
        val_data_tt, val_tp_tt = next(iter(self.val_loader))
        val_out = self.model.forward(val_data_tt, val_tp_tt[0], args)
        val_elbo = self.model.get_elbo(val_data_tt, *val_out, args)

        self.val_loss_meter.update(val_elbo.item())

    def print_loss(self, epoch):
        print('Epoch: {}, Train ELBO: {:.3f}, Val ELBO: {:.3f}'.format(
            epoch, -self.train_loss_meter.avg, -self.val_loss_meter.avg),
              flush=True)

    def plot_loss(self):
        train_range = range(len(self.train_loss_hist))
        val_range = range(len(self.val_loss_hist))
        plt.plot(train_range, self.train_loss_hist, label='train')
        plt.plot(val_range, self.val_loss_hist, label='validation')
        plt.legend()
        plt.show()

    def plot_val_traj(self, args):
        val_data_tt, val_tp_tt = next(iter(self.val_loader))
        self.plot_func(self.model, val_data_tt, val_tp_tt, **args)
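Note: a hypothetical end-to-end usage sketch of TrainingLoop, built only from its docstrings. ToyODEModel, the tensor shapes, and the args values are placeholders chosen for illustration; the real model must provide forward(data, timepoints, args) and get_elbo(data, *outputs, args) as the loop expects, and TrainingLoop plus its meter dependency are assumed importable from the module above.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

class ToyODEModel(nn.Module):
    """Stand-in model exposing the two methods TrainingLoop calls."""

    def __init__(self, dim=3):
        super().__init__()
        self.net = nn.Linear(dim, dim)

    def forward(self, x, t, args):
        # Ignore the timepoints in this toy; return a tuple like a real model would.
        return (self.net(x),)

    def get_elbo(self, x, recon, args):
        # Gaussian reconstruction term up to a constant, treated as a loss to minimize.
        return ((recon - x) ** 2).mean() / (2 * args['l_std'] ** 2)

# Dataloaders yield (B x L x D data, B x L timepoints), as the class docstring states.
data = torch.randn(32, 10, 3)
timepoints = torch.linspace(0., 1., 10).repeat(32, 1)
loader = DataLoader(TensorDataset(data, timepoints), batch_size=8)

model = ToyODEModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

loop = TrainingLoop(model, train_loader=loader, val_loader=loader)
loop.train(optimizer, args={'max_epoch': 2, 'l_std': 0.1, 'clip_norm': 5.0})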
Example #8
def trainer(model,
            logger,
            loader,
            args,
            data="mnist",
            optimizer=None,
            scheduler=None,
            adv_train=None,
            tboard=True,
            **kwargs):
    # loader : train_loader, train_eval_loader, test_loader
    logger.info("=" * 80)
    logger.info("Train Info")
    logger.info("Model : {}".format(args.model))
    logger.info("Number of blocks : {}".format(args.block))
    logger.info("Number of parameters : {}".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))

    start_time = time.time()
    best_acc = 0.
    device = args.device

    criterion = nn.CrossEntropyLoss().to(args.device)
    logger.info("Criterion : {}".format(criterion.__class__.__name__))
    logger.info("Adversarial Training : {}".format(adv_train))
    logger.info("=" * 80)
    data_gen = inf_generator(loader['train_loader'])
    batches_per_epoch = len(loader['train_loader'])

    best_acc = 0.
    best_loss = 1000.
    best_acc_epoch = 0
    best_loss_epoch = 0
    batch_time_meter = RunningAverageMeter()
    end_time = time.time()
    if args.hist:
        hist_dict = dict()

    adv, stats = adv_train_module(adv_train, model, data, args.iters,
                                  args.device)

    torch.save({
        "state_dict": model.state_dict(),
        "args": args
    }, os.path.join(args.save, "model_acc.pt"))
    torch.save({
        "state_dict": model.state_dict(),
        "args": args
    }, os.path.join(args.save, "model_loss.pt"))
    for itr in tqdm.tqdm(range(args.epochs * batches_per_epoch)):

        if itr % batches_per_epoch == 0 and scheduler is not None:
            scheduler.step()

        model.train()
        optimizer.zero_grad()
        x, y = next(data_gen)
        x = x.to(args.device)
        y = y.to(args.device)

        if adv_train is not None:
            x = adv.perturb(x, y, device=args.device)
            if adv_train == "ball":
                y = torch.cat([y for _ in range(stats[0])])
        model.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)

        loss.backward()
        optimizer.step()

        batch_time_meter.update(time.time() - end_time)
        end_time = time.time()

        if itr % batches_per_epoch == 0:
            model.eval()
            with torch.no_grad():
                train_acc, train_loss = accuracy(
                    model,
                    dataset_loader=loader['train_eval_loader'],
                    device=args.device,
                    criterion=criterion)
                val_acc, val_loss = accuracy(
                    model,
                    dataset_loader=loader['test_loader'],
                    device=args.device,
                    criterion=criterion)
                if val_acc >= best_acc:
                    torch.save({
                        "state_dict": model.state_dict(),
                        "args": args
                    }, os.path.join(args.save, "model_acc.pt"))
                    best_acc = val_acc
                    best_acc_epoch = int(itr // batches_per_epoch)
                if val_loss <= best_loss:
                    torch.save({
                        "state_dict": model.state_dict(),
                        "args": args
                    }, os.path.join(args.save, "model_loss.pt"))
                    best_loss = val_loss
                    best_loss_epoch = int(itr // batches_per_epoch)
                logger.info(
                    "Epoch {:03d} | Time {:.3f} ({:.3f}) | Train loss {:.4f} | Validation loss {:.4f} | Train Acc {:.4f} | Validation Acc {:.4f}"
                    .format(int(itr // batches_per_epoch),
                            batch_time_meter.val, batch_time_meter.avg,
                            train_loss, val_loss, train_acc, val_acc))
            torch.save({
                "state_dict": model.state_dict(),
                "args": args
            }, os.path.join(args.save, "model_final.pt"))

    torch.save({
        "state_dict": model.state_dict(),
        "args": args
    }, os.path.join(args.save, "model_final.pt"))
    if args.hist:
        with open(os.path.join(args.save, "history.json"), "w") as f:
            json.dump(hist_dict, f)

    logger.info("=" * 80)
    logger.info("Required Time : {:03d} minute {:.2f} seconds".format(
        int((time.time() - start_time) // 60),
        (time.time() - start_time) % 60))
    logger.info("Best Acc Epoch : {:03d}".format(best_acc_epoch))
    logger.info("Best Validation Accuracy : {:.4f}".format(best_acc))
    logger.info("Best loss Epoch : {:03d}".format(best_loss_epoch))
    logger.info("Best Validation loss : {:.4f}".format(best_loss))
    logger.info("Train end")
    logger.info("=" * 80)

    return model
def main(datafile='./data/train_.pt',
         epochs=1000,
         learning_rate=1e-3,
         dim_out=10,
         device='cuda:0',
         project_name='em_showers_net_training',
         work_space='schattengenie',
         graph_embedder='GraphNN_KNN_v2',
         edge_classifier='EdgeClassifier_v1',
         patience=10):

    experiment = Experiment(project_name=project_name, workspace=work_space)

    early_stopping = EarlyStopping_(patience=patience, verbose=True)

    device = torch.device(device)
    showers = preprocess_dataset(datafile)
    showers_train, showers_test = train_test_split(showers, random_state=1337)

    train_loader = DataLoader(showers_train, batch_size=1, shuffle=True)
    test_loader = DataLoader(showers_test, batch_size=1, shuffle=True)

    k = showers[0].x.shape[1]
    print(k)
    graph_embedder = str_to_class(graph_embedder)(dim_out=dim_out,
                                                  k=k).to(device)
    edge_classifier = str_to_class(edge_classifier)(dim_out=dim_out).to(device)

    criterion = FocalLoss(gamma=2.)
    optimizer = torch.optim.Adam(list(graph_embedder.parameters()) +
                                 list(edge_classifier.parameters()),
                                 lr=learning_rate)

    loss_train = RunningAverageMeter()
    loss_test = RunningAverageMeter()
    roc_auc_test = RunningAverageMeter()
    pr_auc_test = RunningAverageMeter()
    acc_test = RunningAverageMeter()
    class_disbalance = RunningAverageMeter()

    for _ in tqdm(range(epochs)):
        for shower in train_loader:
            shower = shower.to(device)
            edge_labels_true, edge_labels_predicted = predict_one_shower(
                shower,
                graph_embedder=graph_embedder,
                edge_classifier=edge_classifier)
            # calculate the batch loss
            loss = criterion(edge_labels_predicted, edge_labels_true.float())
            # Zero gradients, perform a backward pass, and update the weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_train.update(loss.item())
            class_disbalance.update((edge_labels_true.sum().float() /
                                     len(edge_labels_true)).item())

        y_true_list = deque()
        y_pred_list = deque()
        for shower in test_loader:
            shower = shower.to(device)
            edge_labels_true, edge_labels_predicted = predict_one_shower(
                shower,
                graph_embedder=graph_embedder,
                edge_classifier=edge_classifier)

            # calculate the batch loss
            loss = criterion(edge_labels_predicted, edge_labels_true.float())
            y_true = edge_labels_true.detach().cpu().numpy()
            y_pred = edge_labels_predicted.detach().cpu().numpy()
            y_true_list.append(y_true)
            y_pred_list.append(y_pred)
            acc = accuracy_score(y_true, y_pred.round())
            roc_auc = roc_auc_score(y_true, y_pred)
            pr_auc = average_precision_score(y_true, y_pred)
            loss_test.update(loss.item())
            acc_test.update(acc)
            roc_auc_test.update(roc_auc)
            pr_auc_test.update(pr_auc)
            class_disbalance.update((edge_labels_true.sum().float() /
                                     len(edge_labels_true)).item())

        #f = plot_aucs(y_true=y_true, y_pred=y_pred)
        #experiment.log_figure("Optimization dynamic", f, overwrite=True)
        experiment_key = experiment.get_key()

        eval_loss = loss_test.val
        early_stopping(eval_loss, graph_embedder, edge_classifier,
                       experiment_key)

        ####
        if early_stopping.early_stop:
            print("Early stopping")
            break
        # TODO: save best
        #torch.save(graph_embedder.state_dict(), "graph_embedder_{}.pt".format(experiment_key))
        #torch.save(edge_classifier.state_dict(), "edge_classifier_{}.pt".format(experiment_key))

        experiment.log_metric('loss_test', loss_test.val)
        experiment.log_metric('acc_test', acc_test.val)
        experiment.log_metric('roc_auc_test', roc_auc_test.val)
        experiment.log_metric('pr_auc_test', pr_auc_test.val)
        experiment.log_metric('class_disbalance', class_disbalance.val)

        y_true = np.concatenate(y_true_list)
        y_pred = np.concatenate(y_pred_list)

    # load the last checkpoint with the best model
    graph_embedder.load_state_dict(
        torch.load("graph_embedder_{}.pt".format(experiment_key)))
    edge_classifier.load_state_dict(
        torch.load("edge_classifier_{}.pt".format(experiment_key)))
Example #10
    def test(self, config, best=False, return_results=True):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.
        """
        losses = RunningAverageMeter()
        top1 = RunningAverageMeter()
        top5 = RunningAverageMeter()

        keep_track_of_results = return_results or self.use_wandb

        if best:
            self.load_checkpoints(best=True, inplace=True, verbose=False)

        if not hasattr(self, 'test_loader'):
            kwargs = {}
            if not config.disable_cuda and torch.cuda.is_available():
                kwargs = {'num_workers': 4, 'pin_memory': True}
            data_dict = get_dataset(config.dataset, config.data_dir, 'test')
            kwargs.update(data_dict)
            self.test_loader = get_test_loader(batch_size=config.batch_size,
                                               **kwargs)

        if keep_track_of_results:
            results = {}
            all_accs = []

        for net, model_name in zip(self.nets, self.model_names):
            net.eval()

            if self.progress_bar:
                pbar = tqdm(total=len(self.test_loader.dataset),
                            leave=False,
                            desc=f'Testing {model_name}')

            # Evaluate every batch; the progress bar is only a display option.
            for i, (images, labels, _, _) in enumerate(self.test_loader):
                if self.use_gpu:
                    images, labels = images.cuda(), labels.cuda()
                images, labels = Variable(images), Variable(labels)

                # forward pass
                with torch.no_grad():
                    outputs = net(images)
                loss = self.loss_ce(outputs, labels).mean()

                # measure accuracy and record loss
                prec_at_1, prec_at_5 = accuracy(outputs.data,
                                                labels.data,
                                                topk=(1, 5))
                losses.update(loss.item(), images.size()[0])
                top1.update(prec_at_1.item(), images.size()[0])
                top5.update(prec_at_5.item(), images.size()[0])

                if self.progress_bar:
                    pbar.update(self.test_loader.batch_size)
            if self.progress_bar:
                pbar.write(
                    '[*] {:5}: Test loss: {:.3f}, top1_acc: {:.3f}%, top5_acc: {:.3f}%'
                    .format(model_name, losses.avg, top1.avg, top5.avg))
                pbar.close()

            fold = 'best' if best else 'last'

            if self.use_wandb:
                wandb.run.summary[f"{fold} test acc {model_name}"] = top1.avg

            if keep_track_of_results:
                results[f'{model_name} test loss'] = losses.avg
                results[f'{model_name} test acc @ 1'] = top1.avg
                results[f'{model_name} test acc @ 5'] = top5.avg
                all_accs.append(top1.avg)

        if keep_track_of_results:
            results['average test acc'] = sum(all_accs) / len(all_accs)
            results['min test acc'] = min(all_accs)
            results['max test acc'] = max(all_accs)

        if best:
            self.load_checkpoints(best=False, inplace=True, verbose=False)

        if self.use_wandb:
            wandb.log(results)

        if return_results:
            return results
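Note: Example #10 calls accuracy(outputs.data, labels.data, topk=(1, 5)) and expects precision@1 / precision@5 back. A standard sketch of such a helper follows; this is an assumption about the project's utility, not its actual code (the accuracy() used in Examples #3 and #8 has a different, loader-based signature).

import torch

def accuracy(output, target, topk=(1,)):
    """Compute precision@k over a batch of logits, returned as percentages."""
    maxk = max(topk)
    batch_size = target.size(0)

    # Top-k predicted class indices per sample, transposed to shape (maxk, B).
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res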