Example #1
def main():
    global best_test_bpd

    last_checkpoints = []
    lipschitz_constants = []
    ords = []

    # if args.resume:
    #     validate(args.begin_epoch - 1, model, ema)
    for epoch in range(args.begin_epoch, args.nepochs):

        logger.info('Current LR {}'.format(optimizer.param_groups[0]['lr']))

        train(epoch, model)
        lipschitz_constants.append(get_lipschitz_constants(model))
        ords.append(get_ords(model))
        logger.info('Lipsh: {}'.format(pretty_repr(lipschitz_constants[-1])))
        logger.info('Order: {}'.format(pretty_repr(ords[-1])))

        if args.ema_val:
            test_bpd = validate(epoch, model, ema)
        else:
            test_bpd = validate(epoch, model)

        if args.scheduler and scheduler is not None:
            scheduler.step()

        if test_bpd < best_test_bpd:
            best_test_bpd = test_bpd
            utils.save_checkpoint(
                {
                    'state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'args': args,
                    'ema': ema,
                    'test_bpd': test_bpd,
                },
                os.path.join(args.save, 'models'),
                epoch,
                last_checkpoints,
                num_checkpoints=5)

        #torch.save({
        #    'state_dict': model.state_dict(),
        #    'optimizer_state_dict': optimizer.state_dict(),
        #    'args': args,
        #    'ema': ema,
        #    'test_bpd': test_bpd,
        #}, os.path.join(args.save, 'models', 'most_recent.pth'))

        torch.save(
            {
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'args': args,
                'ema': ema,
                'test_bpd': test_bpd,
            }, os.path.join(args.save, 'models', '00mostReRecent.pth'))
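
The call above passes a `last_checkpoints` list and `num_checkpoints=5`, i.e. the helper keeps only a rolling window of recent checkpoint files. The body of `utils.save_checkpoint` is not shown in these examples; the following is a minimal sketch of such a helper, assuming the signature inferred from the call above (the file-name pattern is illustrative):

import os

import torch


def save_checkpoint(state, save_dir, epoch, last_checkpoints, num_checkpoints=5):
    """Save `state` for this epoch and drop the oldest files beyond `num_checkpoints`."""
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, 'checkpt-{:04d}.pth'.format(epoch))
    torch.save(state, path)
    last_checkpoints.append(path)
    while len(last_checkpoints) > num_checkpoints:
        stale = last_checkpoints.pop(0)
        if os.path.exists(stale):
            os.remove(stale)
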
Example #2
    def _save_checkpoint(self, epoch, save_dir, is_best=False):
        save_checkpoint(
            {
                'state_dict': self.model.state_dict(),
                'epoch': epoch + 1,
                'optimizer': self.optimizer.state_dict(),
            },
            save_dir,
            is_best=is_best)
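
Examples #2, #7 and #8 call a `save_checkpoint` helper with an `is_best` (or `backup_as_best`) flag whose implementation is not shown. A minimal sketch of the common pattern, assuming the signature used in Example #2; the file names and the extra `filename` parameter are assumptions:

import os
import shutil

import torch


def save_checkpoint(state, save_dir, is_best=False, filename='checkpoint.pth.tar'):
    """Always write the latest state; additionally copy it when it is the best so far."""
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(save_dir, 'best_model.pth.tar'))
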
Example #3
            args.optimizer.load_state_dict(checkpoint['optimizer'])
            print "=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch'])
        else:
            print "=> no checkpoint found at '{}'".format(args.resume)

    res_file = '../results/%s.txt' % args.name
    if not osp.exists('../experiment/results'):
        os.makedirs('../experiment/results')
    save_dir = '../models/%s' % args.name
    if not osp.exists(save_dir):
        os.makedirs(save_dir)

    headers = [
        "Epoch", "Pre R@50", "ZS", "R@100", "ZS", "Rel R@50", "ZS", "R@100",
        "ZS"
    ]
    res = []
    for epoch in range(args.start_epoch, args.epochs):
        train_net(train_data_layer, net, epoch, args)
        res.append((epoch, ) + test_pre_net(net, args) +
                   test_rel_net(net, args))
        with open(res_file, 'w') as f:
            f.write(tabulate(res, headers))
        save_checkpoint(
            '%s/epoch_%d_checkpoint.pth.tar' % (save_dir, epoch), {
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'optimizer': args.optimizer.state_dict(),
            })
Example #4
                                        'ffii{:04d}.jpg'.format(itr))
            print('')

            print(fig_filename)
            print('')

            utils.makedirs(os.path.dirname(fig_filename))
            plt.savefig(fig_filename)

            #plt.ion()
            #plt.show()

            #plt.pause(0.1)
            #plt.close()

            utils.save_checkpoint({'state_dict': genGen.state_dict()},
                                  os.path.join(args.save, 'myModels3'), itr)
            #utils.save_checkpoint({'state_dict': genGen.state_dict()}, os.path.join(args.save, 'myModels'), args.niters2)

        #loss_meter.update(loss.item())
        #logpz_meter.update(logpz.item())

        loss2_meter.update(lossGen.item())

        #delta_logp_meter.update(delta_logp.item())

        #loss.backward()
        #lossGen.backward()

        #lossGen.backward(create_graph=True)
        lossGen.backward()
Example #5
def main():

    global args, best_prec1
    args = parser.parse_args()

    my_whole_seed = 222
    random.seed(my_whole_seed)
    np.random.seed(my_whole_seed)
    torch.manual_seed(my_whole_seed)
    torch.cuda.manual_seed_all(my_whole_seed)
    torch.cuda.manual_seed(my_whole_seed)
    np.random.seed(my_whole_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(my_whole_seed)

    for kk_time in range(args.seedstart, args.seedend):
        args.seed = kk_time
        args.result = args.result + str(args.seed)

        # create model
        from models.resnet_sup import resnet18, resnet50, resnet34
        model = resnet18()

        # pretrain_dict = torch.load("resnet18-5c106cde.pth")
        # model_dict = model.state_dict()
        # pretrained_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict}
        # pretrained_dict.pop("fc.weight")
        # pretrained_dict.pop("fc.bias")
        # model_dict.update(pretrained_dict)
        # model.load_state_dict(model_dict)

        model = torch.nn.DataParallel(model).cuda()
        model_weights = torch.load(
            "exp/fundus_dr/DR_miccai_repeat0/fold0-epoch-800.pth.tar")
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in model_weights["state_dict"].items() if k in model_dict
        }
        pretrained_dict.pop("module.fc.weight")
        pretrained_dict.pop("module.fc.bias")
        model.load_state_dict(pretrained_dict, strict=False)

        # Data loading code
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        aug = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            # transforms.RandomGrayscale(p=0.2),
            # transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])
        # aug = transforms.Compose([transforms.RandomResizedCrop(224, scale=(0.08, 1.), ratio=(3 / 4, 4 / 3)),
        #                           transforms.RandomHorizontalFlip(p=0.5),
        #                           get_color_distortion(s=1),
        #                           transforms.Lambda(lambda x: gaussian_blur(x)),
        #                           transforms.ToTensor(),
        #                           normalize])
        # aug = transforms.Compose([transforms.RandomRotation(60),
        #                           transforms.RandomResizedCrop(224, scale=(0.6, 1.)),
        #                           transforms.RandomGrayscale(p=0.2),
        #                           transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        #                           transforms.RandomHorizontalFlip(),
        #                           transforms.ToTensor(),
        #                             normalize])
        aug_test = transforms.Compose(
            [transforms.Resize(224),
             transforms.ToTensor(), normalize])

        # dataset
        import datasets.fundus_amd_syn_crossvalidation as medicaldata
        train_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug,
                                                 train=True,
                                                 args=args)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=4,
            drop_last=True if args.multiaug else False,
            worker_init_fn=random.seed(my_whole_seed))

        valid_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug_test,
                                                 train=False,
                                                 args=args)
        val_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=4,
            worker_init_fn=random.seed(my_whole_seed))

        criterion = nn.CrossEntropyLoss().cuda()
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     weight_decay=args.weight_decay)

        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                model_dict = model.state_dict()

                pretrained_dict = {
                    k: v
                    for k, v in checkpoint["state_dict"].items()
                    if k in model_dict
                }
                pretrained_dict.pop("module.fc.weight")
                pretrained_dict.pop("module.fc.bias")
                # pretrained_dict = {k: v for k, v in checkpoint["net"].items() if k in model_dict}
                # pretrained_dict.pop("module.conv1.weight")
                # pretrained_dict.pop("module.conv1.bias")

                model_dict.update(pretrained_dict)
                model.load_state_dict(model_dict)
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))

            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        # mkdir result folder and tensorboard
        os.makedirs(args.result, exist_ok=True)
        writer = SummaryWriter("runs/" + str(args.result.split("/")[-1]))
        writer.add_text('Text', str(args))

        # copy code
        import shutil, glob
        source = glob.glob("*.py")
        source += glob.glob("*/*.py")
        os.makedirs(args.result + "/code_file", exist_ok=True)
        for file in source:
            name = file.split("/")[0]
            if name == file:
                shutil.copy(file, args.result + "/code_file/")
            else:
                os.makedirs(args.result + "/code_file/" + name, exist_ok=True)
                shutil.copy(file, args.result + "/code_file/" + name)

        for epoch in range(args.start_epoch, args.epochs):
            lr = adjust_learning_rate(optimizer, epoch, args,
                                      [500, 1000, 1500])
            writer.add_scalar("lr", lr, epoch)

            # # train for one epoch
            loss = train(train_loader, model, criterion, optimizer)
            writer.add_scalar("train_loss", loss, epoch)

            gap_int = 200
            if (epoch) % gap_int == 0:
                loss_val, auc, acc, precision, recall, f1score = supervised_evaluation(
                    model, val_loader)
                writer.add_scalar("test_auc", auc, epoch)
                writer.add_scalar("test_acc", acc, epoch)
                writer.add_scalar("test_precision", precision, epoch)
                writer.add_scalar("test_recall", recall, epoch)
                writer.add_scalar("test_f1score", f1score, epoch)

                # save to txt
                f = open(args.result + "/result.txt", "a+")
                f.write("epoch " + str(epoch) + "\n")
                f.write("auc: %.4f\n" % (auc))
                f.write("acc: %.4f\n" % (acc))
                f.write("pre: %.4f\n" % (precision))
                f.write("recall: %.4f\n" % (recall))
                f.write("f1score: %.4f\n" % (f1score))
                f.close()

                # save checkpoint
            if epoch in [1000, 2000, 3000]:
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    },
                    filename=args.result + "/epoch-" + str(epoch) + ".pth.tar")
Example #6
def main():
    global best_test_bpd

    last_checkpoints = []
    lipschitz_constants = []
    ords = []

    # if args.resume:
    #     validate(args.begin_epoch - 1, model, ema)

    liveloss = PlotLosses()

    for epoch in range(args.begin_epoch, args.nepochs):
        logs = {}

        logger.info('Current LR {}'.format(optimizer.param_groups[0]['lr']))

        running_loss = train(epoch, model)

        #train(epoch, model)
        lipschitz_constants.append(get_lipschitz_constants(model))

        ords.append(get_ords(model))

        logger.info('Lipsh: {}'.format(pretty_repr(lipschitz_constants[-1])))
        logger.info('Order: {}'.format(pretty_repr(ords[-1])))

        #epoch_loss = running_loss / len(dataloaders[phase].dataset)
        epoch_loss = running_loss / len(
            datasets.CIFAR10(
                args.dataroot, train=True, transform=transform_train))

        logs['log loss'] = epoch_loss.item()

        liveloss.update(logs)
        liveloss.draw()

        if args.ema_val:
            test_bpd = validate(epoch, model, ema)
        else:
            test_bpd = validate(epoch, model)

        if args.scheduler and scheduler is not None:
            scheduler.step()

        if test_bpd < best_test_bpd:
            best_test_bpd = test_bpd

            utils.save_checkpoint(
                {
                    'state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'args': args,
                    'ema': ema,
                    'test_bpd': test_bpd,
                },
                os.path.join(args.save, 'moMoModels'),
                epoch,
                last_checkpoints,
                num_checkpoints=5)
            """
            utils.save_checkpoint({
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'args': args,
                'ema': ema,
                'test_bpd': test_bpd,
            }, os.path.join(args.save, 'mMoModels'), epoch, last_checkpoints, num_checkpoints=5)
            
            utils.save_checkpoint({
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'args': args,
                'ema': ema,
                'test_bpd': test_bpd,
            }, os.path.join(args.save, 'mModels'), epoch, last_checkpoints, num_checkpoints=5)
            
            utils.save_checkpoint({
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'args': args,
                'ema': ema,
                'test_bpd': test_bpd,
            }, os.path.join(args.save, 'models'), epoch, last_checkpoints, num_checkpoints=5)
            """

        torch.save(
            {
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'args': args,
                'ema': ema,
                'test_bpd': test_bpd,
            }, os.path.join(args.save, 'models',
                            '010mmoosttMoosttRecentt.pth'))
        """
Example #7
def main():

    logger.info('Start to declare training variable')
    cfg.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logger.info('Session will be ran in device: [%s]' % cfg.device)
    start_epoch = 0
    best_acc = 0.

    logger.info('Start to prepare data')
    # get transformers
    # train_transform is for data perturbation
    train_transform = transforms.get(train=True)
    # test_transform is for evaluation
    test_transform = transforms.get(train=False)
    # reduced_transform is for original training data
    reduced_transform = get_reduced_transform(cfg.tfm_resize, cfg.tfm_size,
                                              cfg.tfm_means, cfg.tfm_stds)
    # get datasets
    # each head should have its own trainset
    train_splits = dict(cifar100=[['train', 'test']],
                        stl10=[['train+unlabeled', 'test'], ['train', 'test']])
    test_splits = dict(cifar100=['train', 'test'], stl10=['train', 'test'])
    # instance dataset for each head
    # otrainset: original trainset
    otrainset = [
        ConcatDataset([
            datasets.get(split=split, transform=reduced_transform)
            for split in train_splits[cfg.dataset][hidx]
        ]) for hidx in range(len(train_splits[cfg.dataset]))
    ]
    # ptrainset: perturbed trainset
    ptrainset = [
        ConcatDataset([
            datasets.get(split=split, transform=train_transform)
            for split in train_splits[cfg.dataset][hidx]
        ]) for hidx in range(len(train_splits[cfg.dataset]))
    ]
    # testset
    testset = ConcatDataset([
        datasets.get(split=split, transform=test_transform)
        for split in test_splits[cfg.dataset]
    ])
    # declare data loaders for testset only
    test_loader = DataLoader(testset,
                             batch_size=cfg.batch_size,
                             shuffle=False,
                             num_workers=cfg.num_workers)

    logger.info('Start to build model')
    net = networks.get()
    criterion = PUILoss(cfg.pica_lamda)
    optimizer = optimizers.get(
        params=[val for _, val in net.trainable_parameters().items()])
    lr_handler = lr_policy.get()

    # load session if checkpoint is provided
    if cfg.resume:
        assert os.path.exists(cfg.resume), "Resume file not found"
        ckpt = torch.load(cfg.resume)
        logger.info('Start to resume session for file: [%s]' % cfg.resume)
        net.load_state_dict(ckpt['net'])
        best_acc = ckpt['acc']
        start_epoch = ckpt['epoch']

    # move modules to target device
    net, criterion = net.to(cfg.device), criterion.to(cfg.device)

    # tensorboard writer
    writer = SummaryWriter(cfg.debug, log_dir=cfg.tfb_dir)
    # start training
    lr = cfg.base_lr
    epoch = start_epoch
    while lr > 0 and epoch < cfg.max_epochs:

        lr = lr_handler.update(epoch, optimizer)
        writer.add_scalar('Train/Learning_Rate', lr, epoch)

        logger.info('Start to train at %d epoch with learning rate %.5f' %
                    (epoch, lr))
        train(epoch, net, otrainset, ptrainset, optimizer, criterion, writer)

        logger.info('Start to evaluate after %d epoch of training' % epoch)
        acc, nmi, ari = evaluate(net, test_loader)
        logger.info('Evaluation results at epoch %d are: '
                    'ACC: %.3f, NMI: %.3f, ARI: %.3f' % (epoch, acc, nmi, ari))
        writer.add_scalar('Evaluate/ACC', acc, epoch)
        writer.add_scalar('Evaluate/NMI', nmi, epoch)
        writer.add_scalar('Evaluate/ARI', ari, epoch)

        epoch += 1

        if cfg.debug:
            continue

        # save checkpoint
        is_best = acc > best_acc
        best_acc = max(best_acc, acc)
        save_checkpoint(
            {
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'acc': acc,
                'epoch': epoch
            },
            is_best=is_best)

    logger.info('Done')
Example #8
    if args.mode == 'train':
        LOGGER.info("Starting training ...")
        for epoch in range(10):
            trainer.train(epoch=epoch)
            score = trainer.test(epoch=epoch, val=True)
            LOGGER.info("F2-Score: {}".format(score))
            if score > best_score:
                is_best = True
                best_score = score
            else:
                is_best = False
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'f2_score': score,
                    'state_dict': trainer.model.state_dict()
                },
                experiment_path,
                backup_as_best=is_best)

    elif args.mode == 'finetune':
        LOGGER.info("Starting fine-tuning ...")
        trainer.lower_lr()
        for epoch in range(100):
            trainer.finetune(epoch=epoch)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'f2_score': 0,
                    'state_dict': trainer.model.state_dict()
                },
Example #9
    def train_gan(self, generator, discriminator):
        """ Implements Training routine for ESRGAN
        Args:
          generator: Model object for the Generator
          discriminator: Model object for the Discriminator
    """
        phase_args = self.settings["train_combined"]
        decay_args = phase_args["adam"]["decay"]
        decay_factor = decay_args["factor"]
        decay_steps = decay_args["step"]
        lambda_ = phase_args["lambda"]
        hr_dimension = self.settings["dataset"]["hr_dimension"]
        eta = phase_args["eta"]
        total_steps = phase_args["num_steps"]
        optimizer = partial(tf.optimizers.Adam,
                            learning_rate=phase_args["adam"]["initial_lr"],
                            beta_1=phase_args["adam"]["beta_1"],
                            beta_2=phase_args["adam"]["beta_2"])

        G_optimizer = optimizer()
        D_optimizer = optimizer()

        ra_gen = utils.RelativisticAverageLoss(discriminator, type_="G")
        ra_disc = utils.RelativisticAverageLoss(discriminator, type_="D")

        # The weights of the generator trained during Phase #1
        # are used to initialize or "hot start" the generator
        # for phase #2 of training
        status = None
        checkpoint = tf.train.Checkpoint(G=generator,
                                         G_optimizer=G_optimizer,
                                         D=discriminator,
                                         D_optimizer=D_optimizer)
        if not tf.io.gfile.exists(
                os.path.join(self.model_dir,
                             self.settings["checkpoint_path"]["phase_2"],
                             "checkpoint")):
            hot_start = tf.train.Checkpoint(G=generator,
                                            G_optimizer=G_optimizer)
            status = utils.load_checkpoint(hot_start, "phase_1",
                                           self.model_dir)
            # consuming variable from checkpoint
            G_optimizer.learning_rate.assign(phase_args["adam"]["initial_lr"])
        else:
            status = utils.load_checkpoint(checkpoint, "phase_2",
                                           self.model_dir)

        logging.debug("phase status object: {}".format(status))

        gen_metric = tf.keras.metrics.Mean()
        disc_metric = tf.keras.metrics.Mean()
        psnr_metric = tf.keras.metrics.Mean()
        logging.debug("Loading Perceptual Model")
        perceptual_loss = utils.PerceptualLoss(
            weights="imagenet",
            input_shape=[hr_dimension, hr_dimension, 3],
            loss_type=phase_args["perceptual_loss_type"])
        logging.debug("Loaded Model")

        def _step_fn(image_lr, image_hr):
            logging.debug("Starting Distributed Step")
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                fake = generator.unsigned_call(image_lr)
                logging.debug("Fetched Generator Fake")
                fake = utils.preprocess_input(fake)
                image_lr = utils.preprocess_input(image_lr)
                image_hr = utils.preprocess_input(image_hr)
                percep_loss = tf.reduce_mean(perceptual_loss(image_hr, fake))
                logging.debug("Calculated Perceptual Loss")
                l1_loss = utils.pixel_loss(image_hr, fake)
                logging.debug("Calculated Pixel Loss")
                loss_RaG = ra_gen(image_hr, fake)
                logging.debug("Calculated Relativistic"
                              "Averate (RA) Loss for Generator")
                disc_loss = ra_disc(image_hr, fake)
                logging.debug("Calculated RA Loss Discriminator")
                gen_loss = percep_loss + lambda_ * loss_RaG + eta * l1_loss
                logging.debug("Calculated Generator Loss")
                disc_metric(disc_loss)
                gen_metric(gen_loss)
                gen_loss = gen_loss * (1.0 / self.batch_size)
                disc_loss = disc_loss * (1.0 / self.batch_size)
                psnr_metric(
                    tf.reduce_mean(tf.image.psnr(fake, image_hr,
                                                 max_val=256.0)))
            disc_grad = disc_tape.gradient(disc_loss,
                                           discriminator.trainable_variables)
            logging.debug("Calculated gradient for Discriminator")
            D_optimizer.apply_gradients(
                zip(disc_grad, discriminator.trainable_variables))
            logging.debug("Applied gradients to Discriminator")
            gen_grad = gen_tape.gradient(gen_loss,
                                         generator.trainable_variables)
            logging.debug("Calculated gradient for Generator")
            G_optimizer.apply_gradients(
                zip(gen_grad, generator.trainable_variables))
            logging.debug("Applied gradients to Generator")

            return tf.cast(D_optimizer.iterations, tf.float32)

        @tf.function
        def train_step(image_lr, image_hr):
            distributed_iterations = self.strategy.experimental_run_v2(
                _step_fn, args=(image_lr, image_hr))
            num_steps = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                             distributed_iterations,
                                             axis=None)
            return num_steps

        start = time.time()
        last_psnr = 0
        while True:
            image_lr, image_hr = next(self.dataset)
            num_step = train_step(image_lr, image_hr)
            if num_step >= total_steps:
                return
            if status:
                status.assert_consumed()
                logging.info("consumed checkpoint successfully!")
                status = None
            # Decaying Learning Rate
            for _step in decay_steps.copy():
                if num_step >= _step:
                    decay_steps.pop(0)
                    g_current_lr = self.strategy.reduce(
                        tf.distribute.ReduceOp.MEAN,
                        G_optimizer.learning_rate,
                        axis=None)

                    d_current_lr = self.strategy.reduce(
                        tf.distribute.ReduceOp.MEAN,
                        D_optimizer.learning_rate,
                        axis=None)

                    logging.debug("Current LR: G = %s, D = %s" %
                                  (g_current_lr, d_current_lr))
                    logging.debug("[Phase 2] Decayed Learing Rate by %f." %
                                  decay_factor)
                    G_optimizer.learning_rate.assign(
                        G_optimizer.learning_rate * decay_factor)
                    D_optimizer.learning_rate.assign(
                        D_optimizer.learning_rate * decay_factor)

            # Writing Summary
            with self.summary_writer_2.as_default():
                tf.summary.scalar("gen_loss",
                                  gen_metric.result(),
                                  step=D_optimizer.iterations)
                tf.summary.scalar("disc_loss",
                                  disc_metric.result(),
                                  step=D_optimizer.iterations)
                tf.summary.scalar("mean_psnr",
                                  psnr_metric.result(),
                                  step=D_optimizer.iterations)

            # Logging and Checkpointing
            if not num_step % self.settings["print_step"]:
                logging.info("Step: {}\tGen Loss: {}\tDisc Loss: {}"
                             "\tPSNR: {}\tTime Taken: {} sec".format(
                                 num_step, gen_metric.result(),
                                 disc_metric.result(), psnr_metric.result(),
                                 time.time() - start))
                # if psnr_metric.result() > last_psnr:
                last_psnr = psnr_metric.result()
                utils.save_checkpoint(checkpoint, "phase_2", self.model_dir)
                start = time.time()
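
Examples #9, #11 and #12 rely on `utils.save_checkpoint`/`utils.load_checkpoint` helpers that are not shown and whose call sites differ slightly (with and without a model-directory argument). A minimal sketch of how such phase-keyed helpers could be built on `tf.train.CheckpointManager`; the directory layout, the `max_to_keep` value and the manager cache are assumptions:

import os

import tensorflow as tf

# One CheckpointManager per training phase, cached so save and load reuse it.
_MANAGERS = {}


def _manager(checkpoint, phase, base_dir, max_to_keep=3):
    if phase not in _MANAGERS:
        _MANAGERS[phase] = tf.train.CheckpointManager(
            checkpoint, directory=os.path.join(base_dir, phase), max_to_keep=max_to_keep)
    return _MANAGERS[phase]


def save_checkpoint(checkpoint, phase, base_dir='checkpoints'):
    """Write a numbered checkpoint for the given phase and return its path."""
    return _manager(checkpoint, phase, base_dir).save()


def load_checkpoint(checkpoint, phase, base_dir='checkpoints'):
    """Restore the newest checkpoint of the phase, if any; return the restore status or None."""
    latest = _manager(checkpoint, phase, base_dir).latest_checkpoint
    return checkpoint.restore(latest) if latest else None
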
Example #10
def main():
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('-d', '--dataset', default='celeba', type=str, help='dataset name',
        choices=['celeba'])
    parser.add_argument('-dist', default='normal', type=str, choices=['normal', 'laplace', 'flow'])
    parser.add_argument('-n', '--num-epochs', default=50, type=int, help='number of training epochs')
    parser.add_argument('-b', '--batch-size', default=2048, type=int, help='batch size')
    parser.add_argument('-l', '--learning-rate', default=1e-3, type=float, help='learning rate')
    parser.add_argument('-z', '--latent-dim', default=100, type=int, help='size of latent dimension')
    parser.add_argument('--beta', default=1, type=float, help='ELBO penalty term')
    parser.add_argument('--beta_sens', default=20, type=float, help='Relative importance of predicting sensitive attributes')
    #parser.add_argument('--sens_idx', default=[13, 15, 20], type=list, help='Relative importance of predicting sensitive attributes')
    parser.add_argument('--tcvae', action='store_true')
    parser.add_argument('--exclude-mutinfo', action='store_true')
    parser.add_argument('--beta-anneal', action='store_true')
    parser.add_argument('--lambda-anneal', action='store_true')
    parser.add_argument('--mss', action='store_true', help='use the improved minibatch estimator')
    parser.add_argument('--conv', action='store_true')
    parser.add_argument('--clf_samps', action='store_true')
    parser.add_argument('--clf_means', action='store_false', dest='clf_samps')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--visdom', action='store_true', help='whether plotting in visdom is desired')
    parser.add_argument('--save', default='betatcvae-celeba')
    parser.add_argument('--log_freq', default=200, type=int, help='num iterations per log')
    parser.add_argument('--audit', action='store_true',
            help='after each epoch, audit the repr wrt fair clf task')
    args = parser.parse_args()
    print(args)
    
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    writer = SummaryWriter(args.save)
    writer.add_text('args', json.dumps(vars(args), sort_keys=True, indent=4))

    log_file = os.path.join(args.save, 'train.log')
    if os.path.exists(log_file):
        os.remove(log_file)

    print(vars(args))
    print(vars(args), file=open(log_file, 'w'))

    torch.cuda.set_device(args.gpu)

    # data loader
    loaders = setup_data_loaders(args, use_cuda=True)

    # setup the VAE
    if args.dist == 'normal':
        prior_dist = dist.Normal()
        q_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
        q_dist = dist.Laplace()
    elif args.dist == 'flow':
        prior_dist = FactorialNormalizingFlow(dim=args.latent_dim, nsteps=32)
        q_dist = dist.Normal()

    x_dist = dist.Normal() if args.dataset == 'celeba' else dist.Bernoulli()
    a_dist = dist.Bernoulli()
    vae = SensVAE(z_dim=args.latent_dim, use_cuda=True, prior_dist=prior_dist, 
            q_dist=q_dist, include_mutinfo=not args.exclude_mutinfo, 
            tcvae=args.tcvae, conv=args.conv, mss=args.mss, 
            n_chan=3 if args.dataset == 'celeba' else 1, sens_idx=SENS_IDX,
            x_dist=x_dist, a_dist=a_dist, clf_samps=args.clf_samps)

    if args.audit:
        audit_label_fn = get_label_fn(
                dict(data=dict(name='celeba', label_fn='H'))
                )
        audit_repr_fns = dict()
        audit_attr_fns = dict()
        audit_models = dict()
        audit_train_metrics = dict()
        audit_validation_metrics = dict()
        for attr_fn_name in CELEBA_SENS_IDX.keys():
            model = MLPClassifier(args.latent_dim, 1000, 2)
            model.cuda()
            audit_models[attr_fn_name] = model
            audit_repr_fns[attr_fn_name] = get_repr_fn(
                dict(data=dict(
                    name='celeba', repr_fn='remove_all', attr_fn=attr_fn_name))
                )
            audit_attr_fns[attr_fn_name] = get_attr_fn(
                dict(data=dict(name='celeba', attr_fn=attr_fn_name))
                )

    # setup the optimizer
    optimizer = optim.Adam(vae.parameters(), lr=args.learning_rate)
    if args.audit:
        Adam = optim.Adam
        audit_optimizers = dict()
        for k, v in audit_models.items():
            audit_optimizers[k] = Adam(v.parameters(), lr=args.learning_rate)


    # setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=3776)

    train_elbo = []
    train_tc = []

    # training loop
    dataset_size = len(loaders['train'].dataset)
    num_iterations = len(loaders['train']) * args.num_epochs
    iteration = 0
    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    tc_running_mean = utils.RunningAverageMeter()
    clf_acc_meters = {'clf_acc{}'.format(s): utils.RunningAverageMeter() for s in vae.sens_idx}

    val_elbo_running_mean = utils.RunningAverageMeter()
    val_tc_running_mean = utils.RunningAverageMeter()
    val_clf_acc_meters = {'val_clf_acc{}'.format(s): utils.RunningAverageMeter() for s in vae.sens_idx}


    while iteration < num_iterations:
        bar = tqdm(range(len(loaders['train'])))
        for i, (x, a) in enumerate(loaders['train']):
            bar.update()
            iteration += 1
            batch_time = time.time()
            vae.train()
            #anneal_kl(args, vae, iteration)  # TODO try annealing beta/beta_sens
            vae.beta = args.beta
            vae.beta_sens = args.beta_sens
            optimizer.zero_grad()
            # transfer to GPU
            x = x.cuda(non_blocking=True)
            a = a.float()
            a = a.cuda(non_blocking=True)
            # wrap the mini-batch in a PyTorch Variable
            x = Variable(x)
            a = Variable(a)
            # do ELBO gradient and accumulate loss
            obj, elbo, metrics = vae.elbo(x, a, dataset_size)
            if utils.isnan(obj).any():
                raise ValueError('NaN spotted in objective.')
            obj.mean().mul(-1).backward()
            elbo_running_mean.update(elbo.mean().data.item())
            tc_running_mean.update(metrics['tc'])
            for (s, meter), (_, acc) in zip(clf_acc_meters.items(), metrics.items()):
                clf_acc_meters[s].update(acc.data.item())
            optimizer.step()

            if args.audit:
                for model in audit_models.values():
                    model.train()
                # now re-encode x and take a step to train each audit classifier
                for opt in audit_optimizers.values():
                    opt.zero_grad()
                with torch.no_grad():
                    zs, z_params = vae.encode(x)
                    if args.clf_samps:
                        z = zs
                    else:
                        z_mu = z_params.select(-1, 0)
                        z = z_mu
                    a_all = a
                for subgroup, model in audit_models.items():
                    # noise out sensitive dims of latent code
                    z_ = z.clone()
                    a_all_ = a_all.clone()
                    # subsample to just sens attr of interest for this subgroup
                    a_ = audit_attr_fns[subgroup](a_all_)
                    # noise out sensitive dims for this subgroup
                    z_ = audit_repr_fns[subgroup](z_, None, None)
                    y_ = audit_label_fn(a_all_).long()

                    loss, _, metrics = model(z_, y_, a_)
                    loss.backward()
                    audit_optimizers[subgroup].step()
                    metrics_dict = {}
                    metrics_dict.update(loss=loss.detach().item())
                    for k, v in metrics.items():
                        if v.numel() > 1:
                            k += '-avg'
                            v = v.float().mean()
                        metrics_dict.update({k:v.detach().item()})
                    audit_train_metrics[subgroup] = metrics_dict

            # report training diagnostics
            if iteration % args.log_freq == 0:
                if args.audit:
                    for subgroup, metrics in audit_train_metrics.items():
                        for metric_name, metric_value in metrics.items():
                            writer.add_scalar(
                                    '{}/{}'.format(subgroup, metric_name),
                                    metric_value, iteration)

                train_elbo.append(elbo_running_mean.avg)
                writer.add_scalar('train_elbo', elbo_running_mean.avg, iteration)
                train_tc.append(tc_running_mean.avg)
                writer.add_scalar('train_tc', tc_running_mean.avg, iteration)
                msg = '[iteration %03d] time: %.2f \tbeta %.2f \tlambda %.2f training ELBO: %.4f (%.4f) training TC %.4f (%.4f)' % (
                    iteration, time.time() - batch_time, vae.beta, vae.lamb,
                    elbo_running_mean.val, elbo_running_mean.avg,
                    tc_running_mean.val, tc_running_mean.avg)
                for k, v in clf_acc_meters.items():
                    msg += ' {}: {:.2f}'.format(k, v.avg)
                    writer.add_scalar(k, v.avg, iteration)
                print(msg)
                print(msg, file=open(log_file, 'a'))

                vae.eval()
                ################################################################
                # evaluate validation metrics on vae and auditors
                for x, a in loaders['validation']:
                    # transfer to GPU
                    x = x.cuda(non_blocking=True)
                    a = a.float()
                    a = a.cuda(non_blocking=True)
                    # wrap the mini-batch in a PyTorch Variable
                    x = Variable(x)
                    a = Variable(a)
                    # do ELBO gradient and accumulate loss
                    obj, elbo, metrics = vae.elbo(x, a, dataset_size)
                    if utils.isnan(obj).any():
                        raise ValueError('NaN spotted in objective.')
                    #
                    val_elbo_running_mean.update(elbo.mean().data.item())
                    val_tc_running_mean.update(metrics['tc'])
                    for (s, meter), (_, acc) in zip(
                            val_clf_acc_meters.items(), metrics.items()):
                        val_clf_acc_meters[s].update(acc.data.item())

                if args.audit:
                    for model in audit_models.values():
                        model.eval()
                    with torch.no_grad():
                        zs, z_params = vae.encode(x)
                        if args.clf_samps:
                            z = zs
                        else:
                            z_mu = z_params.select(-1, 0)
                            z = z_mu
                        a_all = a
                    for subgroup, model in audit_models.items():
                        # noise out sensitive dims of latent code
                        z_ = z.clone()
                        a_all_ = a_all.clone()
                        # subsample to just sens attr of interest for this subgroup
                        a_ = audit_attr_fns[subgroup](a_all_)
                        # noise out sensitive dims for this subgroup
                        z_ = audit_repr_fns[subgroup](z_, None, None)
                        y_ = audit_label_fn(a_all_).long()

                        loss, _, metrics = model(z_, y_, a_)
                        loss.backward()
                        audit_optimizers[subgroup].step()
                        metrics_dict = {}
                        metrics_dict.update(val_loss=loss.detach().item())
                        for k, v in metrics.items():
                            k = 'val_' + k  # denote a validation metric
                            if v.numel() > 1:
                                k += '-avg'
                                v = v.float().mean()
                            metrics_dict.update({k:v.detach().item()})
                        audit_validation_metrics[subgroup] = metrics_dict

                # after iterating through validation set, write summaries
                for subgroup, metrics in audit_validation_metrics.items():
                    for metric_name, metric_value in metrics.items():
                        writer.add_scalar(
                                '{}/{}'.format(subgroup, metric_name),
                                metric_value, iteration)
                writer.add_scalar('val_elbo', val_elbo_running_mean.avg, iteration)
                writer.add_scalar('val_tc', val_tc_running_mean.avg, iteration)
                for k, v in val_clf_acc_meters.items():
                    writer.add_scalar(k, v.avg, iteration)

                ################################################################
                # finally, plot training and test ELBOs
                if args.visdom:
                    display_samples(vae, x, vis)
                    plot_elbo(train_elbo, vis)
                    plot_tc(train_tc, vis)

                utils.save_checkpoint({
                    'state_dict': vae.state_dict(),
                    'args': args}, args.save, iteration // len(loaders['train']))
                eval('plot_vs_gt_' + args.dataset)(vae, loaders['train'].dataset,
                    os.path.join(args.save, 'gt_vs_latent_{:05d}.png'.format(iteration)))

    # Report statistics after training
    vae.eval()
    utils.save_checkpoint({
        'state_dict': vae.state_dict(),
        'args': args}, args.save, 0)
    dataset_loader = DataLoader(loaders['train'].dataset, batch_size=1000, num_workers=1, shuffle=False)
    if False:
        logpx, dependence, information, dimwise_kl, analytical_cond_kl, marginal_entropies, joint_entropy = \
            elbo_decomposition(vae, dataset_loader)
        torch.save({
            'logpx': logpx,
            'dependence': dependence,
            'information': information,
            'dimwise_kl': dimwise_kl,
            'analytical_cond_kl': analytical_cond_kl,
            'marginal_entropies': marginal_entropies,
            'joint_entropy': joint_entropy
        }, os.path.join(args.save, 'elbo_decomposition.pth'))
    eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset, os.path.join(args.save, 'gt_vs_latent.png'))

    for file in [open(os.path.join(args.save, 'done'), 'w'), sys.stdout]:
        print('done', file=file)

    return vae
Example #11
    def warmup_generator(self, generator):
        """ Training on L1 Loss to warmup the Generator.

    Minimizing the L1 Loss improves the Peak Signal to Noise Ratio (PSNR)
    of the images produced by the generator.
    This trained generator is then used to bootstrap the training of the
    GAN, creating better image inputs instead of random noises.
    Args:
      generator: Model Object for the Generator
    """
        # Loading up phase parameters
        warmup_num_iter = self.settings.get("warmup_num_iter", None)
        phase_args = self.settings["train_psnr"]
        decay_params = phase_args["adam"]["decay"]
        decay_step = decay_params["step"]
        decay_factor = decay_params["factor"]

        metric = tf.keras.metrics.Mean()
        psnr_metric = tf.keras.metrics.Mean()
        tf.summary.experimental.set_step(tf.Variable(0, dtype=tf.int64))
        # Generator Optimizer
        G_optimizer = tf.optimizers.Adam(
            learning_rate=phase_args["adam"]["initial_lr"],
            beta_1=phase_args["adam"]["beta_1"],
            beta_2=phase_args["adam"]["beta_2"])
        checkpoint = tf.train.Checkpoint(
            G=generator,
            G_optimizer=G_optimizer,
            summary_step=tf.summary.experimental.get_step())

        status = utils.load_checkpoint(checkpoint, "phase_1")
        logging.debug("phase_1 status object: {}".format(status))
        previous_loss = float("inf")
        start_time = time.time()
        # Training starts
        for epoch in range(self.iterations):
            metric.reset_states()
            psnr_metric.reset_states()
            for image_lr, image_hr in self.dataset:
                step = tf.summary.experimental.get_step()
                if warmup_num_iter and step % warmup_num_iter:
                    return

                with tf.GradientTape() as tape:
                    fake = generator(image_lr)
                    loss = utils.pixel_loss(image_hr, fake)
                psnr = psnr_metric(
                    tf.reduce_mean(tf.image.psnr(fake, image_hr,
                                                 max_val=256.0)))
                gradient = tape.gradient(loss, generator.trainable_variables)
                G_optimizer.apply_gradients(
                    zip(gradient, generator.trainable_variables))
                mean_loss = metric(loss)

                if status:
                    status.assert_consumed()
                    logging.info(
                        "consumed checkpoint for phase_1 successfully")
                    status = None

                if not step % decay_step and step:  # Decay Learning Rate
                    logging.debug("Learning Rate: %f" %
                                  G_optimizer.learning_rate.numpy())
                    G_optimizer.learning_rate.assign(
                        G_optimizer.learning_rate * decay_factor)
                    logging.debug(
                        "Decayed Learning Rate by %f. Current Learning Rate %f"
                        % (decay_factor, G_optimizer.learning_rate.numpy()))
                with self.summary_writer.as_default():
                    tf.summary.scalar("warmup_loss", mean_loss, step=step)
                    tf.summary.scalar("mean_psnr", psnr, step=step)
                    step.assign_add(1)

                if not step % self.settings["print_step"]:
                    with self.summary_writer.as_default():
                        tf.summary.image(
                            "fake_image",
                            tf.cast(tf.clip_by_value(fake[:1], 0, 255),
                                    tf.uint8),
                            step=step)
                        tf.summary.image("hr_image",
                                         tf.cast(image_hr[:1], tf.uint8),
                                         step=step)

                    logging.info(
                        "[WARMUP] Epoch: {}\tBatch: {}\tGenerator Loss: {}\tPSNR: {}\tTime Taken: {} sec"
                        .format(epoch, step // epoch, mean_loss.numpy(),
                                psnr.numpy(),
                                time.time() - start_time))
                    if mean_loss < previous_loss:
                        utils.save_checkpoint(checkpoint, "phase_1")
                    previous_loss = mean_loss
                    start_time = time.time()
Example #12
    def train_gan(self, generator, discriminator):
        """ Implements Training routine for ESRGAN
        Args:
          generator: Model object for the Generator
          discriminator: Model object for the Discriminator
    """
        phase_args = self.settings["train_combined"]
        decay_args = phase_args["adam"]["decay"]
        decay_factor = decay_args["factor"]
        decay_steps = decay_args["step"]
        lambda_ = phase_args["lambda"]
        hr_dimension = self.settings["dataset"]["hr_dimension"]
        eta = phase_args["eta"]
        tf.summary.experimental.set_step(tf.Variable(0, dtype=tf.int64))
        optimizer = partial(tf.optimizers.Adam,
                            learning_rate=phase_args["adam"]["initial_lr"],
                            beta_1=phase_args["adam"]["beta_1"],
                            beta_2=phase_args["adam"]["beta_2"])

        G_optimizer = optimizer()
        D_optimizer = optimizer()

        ra_gen = utils.RelativisticAverageLoss(discriminator, type_="G")
        ra_disc = utils.RelativisticAverageLoss(discriminator, type_="D")

        # The weights of the generator trained during Phase #1
        # are used to initialize or "hot start" the generator
        # for phase #2 of training
        status = None
        checkpoint = tf.train.Checkpoint(
            G=generator,
            G_optimizer=G_optimizer,
            D=discriminator,
            D_optimizer=D_optimizer,
            summary_step=tf.summary.experimental.get_step())

        if not tf.io.gfile.exists(
                os.path.join(self.settings["checkpoint_path"]["phase_2"],
                             "checkpoint")):
            hot_start = tf.train.Checkpoint(
                G=generator,
                G_optimizer=G_optimizer,
                summary_step=tf.summary.experimental.get_step())
            status = utils.load_checkpoint(hot_start, "phase_1")
            # consuming variable from checkpoint
            tf.summary.experimental.get_step()

            tf.summary.experimental.set_step(tf.Variable(0, dtype=tf.int64))
        else:
            status = utils.load_checkpoint(checkpoint, "phase_2")

        logging.debug("phase status object: {}".format(status))

        gen_metric = tf.keras.metrics.Mean()
        disc_metric = tf.keras.metrics.Mean()
        psnr_metric = tf.keras.metrics.Mean()
        perceptual_loss = utils.PerceptualLoss(
            weights="imagenet",
            input_shape=[hr_dimension, hr_dimension, 3],
            loss_type=phase_args["perceptual_loss_type"])
        for epoch in range(self.iterations):
            # Resetting Metrics
            gen_metric.reset_states()
            disc_metric.reset_states()
            psnr_metric.reset_states()
            start = time.time()
            for (image_lr, image_hr) in self.dataset:
                step = tf.summary.experimental.get_step()

                # Calculating Loss applying gradients
                with tf.GradientTape() as gen_tape, tf.GradientTape(
                ) as disc_tape:
                    fake = generator(image_lr)
                    percep_loss = perceptual_loss(image_hr, fake)
                    l1_loss = utils.pixel_loss(image_hr, fake)
                    loss_RaG = ra_gen(image_hr, fake)
                    disc_loss = ra_disc(image_hr, fake)
                    gen_loss = percep_loss + lambda_ * loss_RaG + eta * l1_loss
                    disc_metric(disc_loss)
                    gen_metric(gen_loss)
                psnr = psnr_metric(
                    tf.reduce_mean(tf.image.psnr(fake, image_hr,
                                                 max_val=256.0)))
                disc_grad = disc_tape.gradient(
                    disc_loss, discriminator.trainable_variables)
                gen_grad = gen_tape.gradient(gen_loss,
                                             generator.trainable_variables)
                D_optimizer.apply_gradients(
                    zip(disc_grad, discriminator.trainable_variables))
                G_optimizer.apply_gradients(
                    zip(gen_grad, generator.trainable_variables))

                if status:
                    status.assert_consumed()
                    logging.info("consumed checkpoint successfully!")
                    status = None

                # Decaying Learning Rate
                for _step in decay_steps.copy():
                    if (step - 1) >= _step:
                        decay_steps.pop()
                        logging.debug("[Phase 2] Decayed Learing Rate by %f." %
                                      decay_factor)
                        G_optimizer.learning_rate.assign(
                            G_optimizer.learning_rate * decay_factor)
                        D_optimizer.learning_rate.assign(
                            D_optimizer.learning_rate * decay_factor)

                # Writing Summary
                with self.summary_writer.as_default():
                    tf.summary.scalar("gen_loss",
                                      gen_metric.result(),
                                      step=step)
                    tf.summary.scalar("disc_loss",
                                      disc_metric.result(),
                                      step=step)
                    tf.summary.scalar("mean_psnr", psnr, step=step)
                    step.assign_add(1)

                # Logging and Checkpointing
                if not step % self.settings["print_step"]:
                    with self.summary_writer.as_default():
                        resized_lr = tf.cast(
                            tf.clip_by_value(
                                tf.image.resize(image_lr[:1],
                                                [hr_dimension, hr_dimension],
                                                method=self.settings["dataset"]
                                                ["scale_method"]), 0, 255),
                            tf.uint8)
                        tf.summary.image("lr_image", resized_lr, step=step)
                        tf.summary.image(
                            "fake_image",
                            tf.cast(tf.clip_by_value(fake[:1], 0, 255),
                                    tf.uint8),
                            step=step)
                        tf.summary.image("hr_image",
                                         tf.cast(image_hr[:1], tf.uint8),
                                         step=step)
                    logging.info(
                        "Epoch: {}\tBatch: {}\tGen Loss: {}\tDisc Loss: {}\tPSNR: {}\tTime Taken: {} sec"
                        .format((epoch + 1),
                                step.numpy() // (epoch + 1),
                                gen_metric.result().numpy(),
                                disc_metric.result().numpy(), psnr.numpy(),
                                time.time() - start))
                    utils.save_checkpoint(checkpoint, "phase_2")
                    start = time.time()
Example #13
def main():
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('-d',
                        '--dataset',
                        default='faces',
                        type=str,
                        help='dataset name',
                        choices=['shapes', 'faces'])
    parser.add_argument('-dist',
                        default='normal',
                        type=str,
                        choices=['normal', 'laplace', 'flow'])
    parser.add_argument('-x_dist',
                        default='normal',
                        type=str,
                        choices=['normal', 'bernoulli'])
    parser.add_argument('-n',
                        '--num-epochs',
                        default=50,
                        type=int,
                        help='number of training epochs')
    parser.add_argument('-b',
                        '--batch-size',
                        default=2048,
                        type=int,
                        help='batch size')
    parser.add_argument('-l',
                        '--learning-rate',
                        default=1e-3,
                        type=float,
                        help='learning rate')
    parser.add_argument('-z',
                        '--latent-dim',
                        default=10,
                        type=int,
                        help='size of latent dimension')
    parser.add_argument('--beta',
                        default=1,
                        type=float,
                        help='ELBO penalty term')
    parser.add_argument('--tcvae', action='store_true')
    parser.add_argument('--exclude-mutinfo', action='store_false')
    parser.add_argument('--beta-anneal', action='store_true')
    parser.add_argument('--lambda-anneal', action='store_true')
    parser.add_argument('--mss',
                        action='store_true',
                        help='use the improved minibatch estimator')
    parser.add_argument('--conv', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--visdom',
                        action='store_true',
                        help='whether plotting in visdom is desired')
    parser.add_argument('--save', default='test2')
    parser.add_argument('--log_freq',
                        default=50,
                        type=int,
                        help='num iterations per log')
    parser.add_argument(
        '-problem',
        default='Climate_ORNL',
        type=str,
        choices=['HEP_SL', 'Climate_ORNL', 'Climate_C', 'Nuclear_Physics'])
    parser.add_argument('--VIB', action='store_true', help='VIB regression')
    parser.add_argument('--UQ',
                        action='store_true',
                        help='Uncertainty Quantification - likelihood')
    parser.add_argument('-name_S',
                        '--name_save',
                        default='',
                        type=str,
                        help='name to save file')
    parser.add_argument('--classification', action='store_true')
    parser.add_argument('--Func_reg', action='store_true')

    args = parser.parse_args()

    torch.cuda.set_device(args.gpu)

    # data loader
    train_loader = setup_data_loaders(args, use_cuda=True)

    # setup the VAE
    if args.dist == 'normal':
        prior_dist = dist.Normal()
        q_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
        q_dist = dist.Laplace()
    elif args.dist == 'flow':
        prior_dist = FactorialNormalizingFlow(dim=args.latent_dim, nsteps=32)
        q_dist = dist.Normal()

    # setup the likelihood distribution
    if args.x_dist == 'normal':
        x_dist = dist.Normal()
    elif args.x_dist == 'bernoulli':
        x_dist = dist.Bernoulli()
    else:
        raise ValueError('x_dist can be Normal or Bernoulli')

    vae = VAE(z_dim=args.latent_dim,
              beta=args.beta,
              use_cuda=True,
              prior_dist=prior_dist,
              q_dist=q_dist,
              x_dist=x_dist,
              x_dist_name=args.x_dist,
              include_mutinfo=not args.exclude_mutinfo,
              tcvae=args.tcvae,
              conv=args.conv,
              mss=args.mss,
              problem=args.problem,
              VIB=args.VIB,
              UQ=args.UQ,
              classification=args.classification)

    if (args.Func_reg):
        args.latent_dim2 = 4
        args.beta2 = 0.0
        prior_dist2 = dist.Normal()
        q_dist2 = dist.Normal()
        x_dist2 = dist.Normal()
        args.x_dist2 = 'normal'
        args.tcvae2 = False
        args.conv2 = False
        args.problem2 = 'Climate_ORNL'
        args.VIB2 = True
        args.UQ2 = False
        args.classification2 = False

        vae2 = VAE(z_dim=args.latent_dim2,
                   beta=args.beta2,
                   use_cuda=True,
                   prior_dist=prior_dist2,
                   q_dist=q_dist2,
                   x_dist=x_dist2,
                   x_dist_name=args.x_dist2,
                   include_mutinfo=not args.exclude_mutinfo,
                   tcvae=args.tcvae2,
                   conv=args.conv2,
                   mss=args.mss,
                   problem=args.problem2,
                   VIB=args.VIB2,
                   UQ=args.UQ2,
                   classification=args.classification2)

    # setup the optimizer
    #optimizer = optim.Adam(vae.parameters(), lr=args.learning_rate)
    if (args.Func_reg):
        params = list(vae.parameters()) + list(vae2.parameters())
        optimizer = optim.RMSprop(params, lr=args.learning_rate)
    else:
        optimizer = optim.RMSprop(vae.parameters(), lr=args.learning_rate)
    # setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=4500)

    train_elbo = []
    train_rmse = []
    train_mae = []
    train_elbo1 = []
    train_elbo2 = []
    train_elbo3 = []
    train_elbo4 = []
    train_rmse2 = []
    train_mae2 = []
    # training loop
    dataset_size = len(train_loader.dataset)
    num_iterations = len(train_loader) * args.num_epochs
    print("num_iteration", len(train_loader), args.num_epochs)
    iteration = 0
    print("likelihood function", args.x_dist, x_dist)

    train_iter = iter(train_loader)
    images = next(train_iter)

    img_max = train_loader.dataset.__getmax__()

    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    elbo_running_rmse = utils.RunningAverageMeter()
    elbo_running_mae = utils.RunningAverageMeter()
    elbo_running_mean1 = utils.RunningAverageMeter()
    elbo_running_mean2 = utils.RunningAverageMeter()
    elbo_running_mean3 = utils.RunningAverageMeter()
    elbo_running_mean4 = utils.RunningAverageMeter()
    elbo_running_rmse2 = utils.RunningAverageMeter()
    elbo_running_mae2 = utils.RunningAverageMeter()
    #plot the data to visualize

    x_test = train_loader.dataset.imgs_test
    x_train = train_loader.dataset.imgs

    def count_parameters(model):
        trainable = sum(p.numel() for p in model.parameters()
                        if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        return (trainable, total)

    while iteration < num_iterations:
        for i, xy in enumerate(train_loader):
            iteration += 1
            batch_time = time.time()
            vae.train()
            #anneal_kl(args, vae, iteration)
            optimizer.zero_grad()
            # transfer to GPU
            if (args.problem == 'HEP_SL'):
                x = xy[0]
                x = x.float()
                x = x.cuda()
                x = Variable(x)

                y = xy[1]
                y = y.cuda()
                y = Variable(y)

                label = xy[2]
                label = label.cuda()
                label = Variable(label)

            # Get the Training Objective
            obj, elbo, x_mean_pred, z_params1, _, _ = vae.elbo(
                x, y, label, dataset_size)
            if utils.isnan(obj).any():
                raise ValueError('NaN spotted in objective.')

            obj.mean().mul(-1).backward()
            elbo_running_mean.update(elbo.mean().item())
            optimizer.step()

            # report training diagnostics
            if iteration % args.log_freq == 0:
                train_elbo.append(elbo_running_mean.avg)

                if (args.VIB):
                    if not args.classification:
                        if (args.UQ):
                            A = x_mean_pred.cpu().data.numpy()[:, :, 0]
                        else:
                            A = x_mean_pred.cpu().data.numpy()
                        B = y.cpu().data.numpy()
                    else:
                        A = x_mean_pred.cpu().data.numpy()
                        B = label.cpu().data.numpy()
                else:
                    A = x_mean_pred.cpu().data.numpy()
                    B = x.cpu().data.numpy()

                rmse = np.sqrt((np.square(A - B)).mean(axis=None))
                mae = np.abs(A - B).mean(axis=None)

                elbo_running_rmse.update(rmse)
                elbo_running_mae.update(mae)

                train_rmse.append(elbo_running_rmse.avg)
                train_mae.append(elbo_running_mae.avg)

                print(
                    '[iteration %03d] time: %.2f \tbeta %.2f \tlambda %.2f training ELBO: %.4f (%.4f) RMSE: %.4f (%.4f) MAE: %.4f (%.4f)'
                    % (iteration, time.time() - batch_time, vae.beta, vae.lamb,
                       elbo_running_mean.val, elbo_running_mean.avg,
                       elbo_running_rmse.val, elbo_running_rmse.avg,
                       elbo_running_mae.val, elbo_running_mae.avg))

                utils.save_checkpoint(
                    {
                        'state_dict': vae.state_dict(),
                        'args': args
                    }, args.save, 0)

                print("max pred:", np.max(A), "max input:", np.max(B),
                      "min pred:", np.min(A), "min input:", np.min(B))

    if (args.problem == 'HEP_SL'):
        x_test = train_loader.dataset.imgs_test
        x_test = x_test.cuda()
        y_test = train_loader.dataset.lens_p_test
        y_test = y_test.cuda()
        label_test = train_loader.dataset.label_test
        label_test = label_test.cuda()

    utils.save_checkpoint({
        'state_dict': vae.state_dict(),
        'args': args
    }, args.save, 0)
    name_save = args.name_save

    Viz_plot.Convergence_plot(train_elbo, train_rmse, train_mae, name_save,
                              args.save)
    Viz_plot.display_samples_pred_mlp(vae, x_test, y_test, label_test,
                                      args.problem, args.VIB, name_save,
                                      args.UQ, args.classification, args.save,
                                      img_max)

    # Report statistics after training
    vae.eval()
    return vae
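This training loop (and the ones in Examples #14 and #15) uses utils.RunningAverageMeter, which is not included in the listing. A sketch of a meter with the interface these loops assume (.update, .val, .avg); the default momentum of 0.97 is an assumption.

class RunningAverageMeter(object):
    """Tracks the latest value in .val and an exponential moving average in .avg."""

    def __init__(self, momentum=0.97):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val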
Example #14
0
def train(
    device, args, model, growth_model, regularization_coeffs, regularization_fns, logger
):
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    full_data = (
        torch.from_numpy(
            args.data.get_data()[args.data.get_times() != args.leaveout_timepoint]
        )
        .type(torch.float32)
        .to(device)
    )

    best_loss = float("inf")
    growth_model.eval()
    end = time.time()
    for itr in range(1, args.niters + 1):
        model.train()
        optimizer.zero_grad()

        # Train
        if args.spectral_norm:
            spectral_norm_power_iteration(model, 1)

        loss = compute_loss(device, args, model, growth_model, logger, full_data)
        loss_meter.update(loss.item())

        if len(regularization_coeffs) > 0:
            # Only regularize on the last timepoint
            reg_states = get_regularization(model, regularization_coeffs)
            reg_loss = sum(
                reg_state * coeff
                for reg_state, coeff in zip(reg_states, regularization_coeffs)
                if coeff != 0
            )
            loss = loss + reg_loss
        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)

        loss.backward()
        optimizer.step()

        # Eval
        nfe_total = count_nfe(model)
        nfe_backward = nfe_total - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)
        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        log_message = (
            "Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) |"
            " NFE Forward {:.0f}({:.1f})"
            " | NFE Backward {:.0f}({:.1f})".format(
                itr,
                time_meter.val,
                time_meter.avg,
                loss_meter.val,
                loss_meter.avg,
                nfef_meter.val,
                nfef_meter.avg,
                nfeb_meter.val,
                nfeb_meter.avg,
            )
        )
        if len(regularization_coeffs) > 0:
            log_message = append_regularization_to_log(
                log_message, regularization_fns, reg_states
            )
        logger.info(log_message)

        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                train_eval(
                    device, args, model, growth_model, itr, best_loss, logger, full_data
                )

        if itr % args.viz_freq == 0:
            if args.data.get_shape()[0] > 2:
                logger.warning("Skipping vis as data dimension is >2")
            else:
                with torch.no_grad():
                    visualize(device, args, model, itr)
        if itr % args.save_freq == 0:
            utils.save_checkpoint(
                {
                    # 'args': args,
                    "state_dict": model.state_dict(),
                    "growth_state_dict": growth_model.state_dict(),
                },
                args.save,
                epoch=itr,
            )
        end = time.time()
    logger.info("Training has finished.")
Example #15
0
def main():
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument(
        '-d',
        '--dataset',
        default='shapes',
        type=str,
        help='dataset name',
        choices=['shapes', 'faces', 'celeba', 'cars3d', '3dchairs'])
    parser.add_argument('-dist',
                        default='normal',
                        type=str,
                        choices=['normal', 'lpnorm', 'lpnested'])
    parser.add_argument('-n',
                        '--num-epochs',
                        default=50,
                        type=int,
                        help='number of training epochs')
    parser.add_argument(
        '--num-iterations',
        default=0,
        type=int,
        help='number of iterations (overrides number of epochs if >0)')
    parser.add_argument('-b',
                        '--batch-size',
                        default=2048,
                        type=int,
                        help='batch size')
    parser.add_argument('-l',
                        '--learning-rate',
                        default=1e-3,
                        type=float,
                        help='learning rate')
    parser.add_argument('-z',
                        '--latent-dim',
                        default=10,
                        type=int,
                        help='size of latent dimension')
    parser.add_argument('-p',
                        '--pnorm',
                        default=4.0 / 3.0,
                        type=float,
                        help='p value of the Lp-norm')
    parser.add_argument(
        '--pnested',
        default='',
        type=str,
        help=
        'nested list representation of the Lp-nested prior, e.g. [2.1, [ [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ], [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ], [2.2, [ [1.0], [1.0], [1.0], [1.0] ] ] ] ]'
    )
    parser.add_argument(
        '--isa',
        default='',
        type=str,
        help=
        'shorthand notation of ISA Lp-nested norm, e.g. [2.1, [(2.2, 4), (2.2, 4), (2.2, 4)]]'
    )
    parser.add_argument('--p0', default=2.0, type=float, help='p0 of ISA')
    parser.add_argument('--p1', default=2.1, type=float, help='p1 of ISA')
    parser.add_argument('--n1', default=6, type=int, help='n1 of ISA')
    parser.add_argument('--p2', default=2.1, type=float, help='p2 of ISA')
    parser.add_argument('--n2', default=6, type=int, help='n2 of ISA')
    parser.add_argument('--p3', default=2.1, type=float, help='p3 of ISA')
    parser.add_argument('--n3', default=6, type=int, help='n3 of ISA')
    parser.add_argument('--scale',
                        default=1.0,
                        type=float,
                        help='scale of LpNested distribution')
    parser.add_argument('--q-dist',
                        default='normal',
                        type=str,
                        choices=['normal', 'laplace'])
    parser.add_argument('--x-dist',
                        default='bernoulli',
                        type=str,
                        choices=['normal', 'bernoulli'])
    parser.add_argument('--beta',
                        default=1,
                        type=float,
                        help='ELBO penalty term')
    parser.add_argument('--tcvae', action='store_true')
    parser.add_argument('--exclude-mutinfo', action='store_true')
    parser.add_argument('--beta-anneal', action='store_true')
    parser.add_argument('--lambda-anneal', action='store_true')
    parser.add_argument('--mss',
                        action='store_true',
                        help='use the improved minibatch estimator')
    parser.add_argument('--conv', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--visdom',
                        action='store_true',
                        help='whether plotting in visdom is desired')
    parser.add_argument('--save', default='test1')
    parser.add_argument('--id', default='1')
    parser.add_argument(
        '--seed',
        default=-1,
        type=int,
        help=
        'seed for pytorch and numpy random number generator to allow reproducibility (default/-1: use random seed)'
    )
    parser.add_argument('--log_freq',
                        default=200,
                        type=int,
                        help='num iterations per log')
    parser.add_argument('--use-mse-loss', action='store_true')
    parser.add_argument('--mse-sigma',
                        default=0.01,
                        type=float,
                        help='sigma of mean squared error loss')
    parser.add_argument('--dip', action='store_true', help='use DIP-VAE')
    parser.add_argument('--dip-type',
                        default=1,
                        type=int,
                        help='DIP type (1 or 2)')
    parser.add_argument('--lambda-od',
                        default=2.0,
                        type=float,
                        help='DIP: lambda weight off-diagonal')
    parser.add_argument('--clip',
                        default=0.0,
                        type=float,
                        help='Gradient clipping (0 disabled)')
    parser.add_argument('--test', action='store_true', help='run test')
    parser.add_argument(
        '--trainingsetsize',
        default=0,
        type=int,
        help='Subsample the trainingset (0 use original training data)')
    args = parser.parse_args()

    # initialize seeds for reproducibility
    if not args.seed == -1:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.gpu != -1:
        print('Using CUDA device {}'.format(args.gpu))
        torch.cuda.set_device(args.gpu)
        use_cuda = True
    else:
        print('CUDA disabled')
        use_cuda = False

    # data loader
    train_loader = setup_data_loaders(args.dataset,
                                      args.batch_size,
                                      use_cuda=use_cuda,
                                      len_subset=args.trainingsetsize)

    # setup the VAE
    if args.dist == 'normal':
        prior_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
    elif args.dist == 'lpnested':
        if not args.isa == '':
            pnested = parseISA(ast.literal_eval(args.isa))
        elif not args.pnested == '':
            pnested = ast.literal_eval(args.pnested)
        else:
            pnested = parseISA([
                args.p0,
                [(args.p1, args.n1), (args.p2, args.n2), (args.p3, args.n3)]
            ])

        print('using Lp-nested prior, pnested = ({}) {}'.format(
            type(pnested), pnested))
        prior_dist = LpNestedAdapter(p=pnested, scale=args.scale)
        args.latent_dim = prior_dist.dimz()
        print('using Lp-nested prior, changed latent dimension to {}'.format(
            args.latent_dim))
    elif args.dist == 'lpnorm':
        prior_dist = LpNestedAdapter(p=[args.pnorm, [[1.0]] * args.latent_dim],
                                     scale=args.scale)

    if args.q_dist == 'normal':
        q_dist = dist.Normal()
    elif args.q_dist == 'laplace':
        q_dist = dist.Laplace()

    if args.x_dist == 'normal':
        x_dist = dist.Normal(sigma=args.mse_sigma)
    elif args.x_dist == 'bernoulli':
        x_dist = dist.Bernoulli()

    if args.dip_type == 1:
        lambda_d = 10.0 * args.lambda_od
    else:
        lambda_d = args.lambda_od

    vae = VAE(z_dim=args.latent_dim,
              use_cuda=use_cuda,
              prior_dist=prior_dist,
              q_dist=q_dist,
              x_dist=x_dist,
              include_mutinfo=not args.exclude_mutinfo,
              tcvae=args.tcvae,
              conv=args.conv,
              mss=args.mss,
              dataset=args.dataset,
              mse_sigma=args.mse_sigma,
              DIP=args.dip,
              DIP_type=args.dip_type,
              lambda_od=args.lambda_od,
              lambda_d=lambda_d)

    # setup the optimizer
    optimizer = optim.Adam([{
        'params': vae.parameters()
    }, {
        'params': prior_dist.parameters(),
        'lr': 5e-4
    }],
                           lr=args.learning_rate)

    # setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=4500)

    train_elbo = []

    # training loop
    dataset_size = len(train_loader.dataset)
    if args.num_iterations == 0:
        num_iterations = len(train_loader) * args.num_epochs
    else:
        num_iterations = args.num_iterations
    iteration = 0
    obj_best_snapshot = float('-inf')
    best_checkpoint_updated = False

    trainingcurve_filename = os.path.join(args.save, 'trainingcurve.csv')
    if not os.path.exists(trainingcurve_filename):
        with open(trainingcurve_filename, 'w') as fd:
            fd.write(
                'iteration,num_iterations,time,elbo_running_mean_val,elbo_running_mean_avg\n'
            )

    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    nan_detected = False
    gradients = []
    while iteration < num_iterations and not nan_detected:
        for i, x in enumerate(train_loader):
            iteration += 1
            batch_time = time.time()
            vae.train()
            anneal_kl(args, vae, iteration)
            optimizer.zero_grad()
            # transfer to GPU
            if use_cuda:
                x = x.cuda()  # async=True)
            # wrap the mini-batch in a PyTorch Variable
            x = Variable(x)
            # do ELBO gradient and accumulate loss
            #with autograd.detect_anomaly():
            obj, elbo, logpx = vae.elbo(prior_dist,
                                        x,
                                        dataset_size,
                                        use_mse_loss=args.use_mse_loss,
                                        mse_sigma=args.mse_sigma)
            if utils.isnan(obj).any():
                print('NaN spotted in objective.')
                print('lpnested: {}'.format(prior_dist.prior.p))
                print("gradient abs max {}".format(
                    max([g.abs().max() for g in gradients])))
                #raise ValueError('NaN spotted in objective.')
                nan_detected = True
                break
            elbo_running_mean.update(elbo.mean().item())

            # save checkpoint of best ELBO
            if obj.mean().item() > obj_best_snapshot:
                obj_best_snapshot = obj.mean().item()
                best_checkpoint = {
                    'state_dict': vae.state_dict(),
                    'state_dict_prior_dist': prior_dist.state_dict(),
                    'args': args,
                    'iteration': iteration,
                    'obj': obj_best_snapshot,
                    'elbo': elbo.mean().item(),
                    'logpx': logpx.mean().item()
                }
                best_checkpoint_updated = True

            #with autograd.detect_anomaly():
            obj.mean().mul(-1).backward()

            gradients = list(
                filter(lambda p: p.grad is not None, vae.parameters()))

            if args.clip > 0:
                torch.nn.utils.clip_grad_norm_(vae.parameters(), args.clip)

            optimizer.step()

            # report training diagnostics
            if iteration % args.log_freq == 0:
                train_elbo.append(elbo_running_mean.avg)
                time_ = time.time() - batch_time
                print(
                    '[iteration %03d/%03d] time: %.2f \tbeta %.2f \tlambda %.2f \tobj %.4f \tlogpx %.4f training ELBO: %.4f (%.4f)'
                    % (iteration, num_iterations, time_, vae.beta, vae.lamb,
                       obj.mean().item(), logpx.mean().item(),
                       elbo_running_mean.val, elbo_running_mean.avg))

                if args.dist == 'lpnested':
                    p0, p1list = backwardsParseISA(prior_dist.prior.p)
                    print('lpnested: {}, {}'.format(p0, p1list))
                print("gradient abs max {}".format(
                    max([g.abs().max() for g in gradients])))

                with open(os.path.join(args.save, 'trainingcurve.csv'),
                          'a') as fd:
                    fd.write('{},{},{},{},{}\n'.format(iteration,
                                                       num_iterations, time_,
                                                       elbo_running_mean.val,
                                                       elbo_running_mean.avg))

                if best_checkpoint_updated:
                    print(
                        'Update best checkpoint [iteration %03d] training ELBO: %.4f'
                        % (best_checkpoint['iteration'],
                           best_checkpoint['elbo']))
                    utils.save_checkpoint(best_checkpoint, args.save, 0)
                    best_checkpoint_updated = False

                vae.eval()
                prior_dist.eval()

                # plot training and test ELBOs
                if args.visdom:
                    if args.dataset == 'celeba':
                        num_channels = 3
                    else:
                        num_channels = 1
                    display_samples(vae, prior_dist, x, vis, num_channels)
                    plot_elbo(train_elbo, vis)

                if iteration % (10 * args.log_freq) == 0:
                    utils.save_checkpoint(
                        {
                            'state_dict': vae.state_dict(),
                            'state_dict_prior_dist': prior_dist.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'args': args,
                            'iteration': iteration,
                            'obj': obj.mean().item(),
                            'torch_random_state': torch.get_rng_state(),
                            'numpy_random_state': np.random.get_state()
                        },
                        args.save,
                        prefix='latest-optimizer-model-')
                    if not args.dataset == 'celeba' and not args.dataset == '3dchairs':
                        eval('plot_vs_gt_' + args.dataset)(
                            vae, train_loader.dataset,
                            os.path.join(
                                args.save,
                                'gt_vs_latent_{:05d}.png'.format(iteration)))

    # Report statistics of best snapshot after training
    vae.load_state_dict(best_checkpoint['state_dict'])
    prior_dist.load_state_dict(best_checkpoint['state_dict_prior_dist'])

    vae.eval()
    prior_dist.eval()

    if args.dataset == 'shapes':
        data_set = dset.Shapes()
    elif args.dataset == 'faces':
        data_set = dset.Faces()
    elif args.dataset == 'cars3d':
        data_set = dset.Cars3d()
    elif args.dataset == 'celeba':
        data_set = dset.CelebA()
    elif args.dataset == '3dchairs':
        data_set = dset.Chairs()
    else:
        raise ValueError('Unknown dataset ' + str(args.dataset))

    print("loaded dataset {} of size {}".format(args.dataset, len(data_set)))

    dataset_loader = DataLoader(data_set,
                                batch_size=1000,
                                num_workers=0,
                                shuffle=False)

    logpx, dependence, information, dimwise_kl, analytical_cond_kl, elbo_marginal_entropies, elbo_joint_entropy = \
        elbo_decomposition(vae, prior_dist, dataset_loader)
    torch.save(
        {
            'args': args,
            'logpx': logpx,
            'dependence': dependence,
            'information': information,
            'dimwise_kl': dimwise_kl,
            'analytical_cond_kl': analytical_cond_kl,
            'marginal_entropies': elbo_marginal_entropies,
            'joint_entropy': elbo_joint_entropy
        }, os.path.join(args.save, 'elbo_decomposition.pth'))
    print('logpx: {:.2f}'.format(logpx))
    if not args.dataset == 'celeba' and not args.dataset == '3dchairs':
        eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset,
                                           os.path.join(
                                               args.save, 'gt_vs_latent.png'))

        metric, metric_marginal_entropies, metric_cond_entropies = eval(
            'disentanglement_metrics.mutual_info_metric_' + args.dataset)(
                vae, dataset_loader.dataset)
        torch.save(
            {
                'args': args,
                'metric': metric,
                'marginal_entropies': metric_marginal_entropies,
                'cond_entropies': metric_cond_entropies,
            }, os.path.join(args.save, 'disentanglement_metric.pth'))
        print('MIG: {:.2f}'.format(metric))

        if args.dist == 'lpnested':
            p0, p1list = backwardsParseISA(prior_dist.prior.p)
            print('p0: {}'.format(p0))
            print('p1: {}'.format(p1list))
            torch.save(
                {
                    'args': args,
                    'logpx': logpx,
                    'dependence': dependence,
                    'information': information,
                    'dimwise_kl': dimwise_kl,
                    'analytical_cond_kl': analytical_cond_kl,
                    'elbo_marginal_entropies': elbo_marginal_entropies,
                    'elbo_joint_entropy': elbo_joint_entropy,
                    'metric': metric,
                    'metric_marginal_entropies': metric_marginal_entropies,
                    'metric_cond_entropies': metric_cond_entropies,
                    'p0': p0,
                    'p1': p1list
                }, os.path.join(args.save, 'combined_data.pth'))
        else:
            torch.save(
                {
                    'args': args,
                    'logpx': logpx,
                    'dependence': dependence,
                    'information': information,
                    'dimwise_kl': dimwise_kl,
                    'analytical_cond_kl': analytical_cond_kl,
                    'elbo_marginal_entropies': elbo_marginal_entropies,
                    'elbo_joint_entropy': elbo_joint_entropy,
                    'metric': metric,
                    'metric_marginal_entropies': metric_marginal_entropies,
                    'metric_cond_entropies': metric_cond_entropies,
                }, os.path.join(args.save, 'combined_data.pth'))

        if args.dist == 'lpnested':
            if args.dataset == 'shapes':
                eval('plot_vs_gt_' + args.dataset)(
                    vae,
                    dataset_loader.dataset,
                    os.path.join(args.save, 'gt_vs_grouped_latent.png'),
                    eval_subspaces=True)

                metric_subspaces, metric_marginal_entropies_subspaces, metric_cond_entropies_subspaces = eval(
                    'disentanglement_metrics.mutual_info_metric_' +
                    args.dataset)(vae,
                                  dataset_loader.dataset,
                                  eval_subspaces=True)
                torch.save(
                    {
                        'args': args,
                        'metric': metric_subspaces,
                        'marginal_entropies':
                        metric_marginal_entropies_subspaces,
                        'cond_entropies': metric_cond_entropies_subspaces,
                    },
                    os.path.join(args.save,
                                 'disentanglement_metric_subspaces.pth'))
                print('MIG grouped by subspaces: {:.2f}'.format(
                    metric_subspaces))

                torch.save(
                    {
                        'args': args,
                        'logpx': logpx,
                        'dependence': dependence,
                        'information': information,
                        'dimwise_kl': dimwise_kl,
                        'analytical_cond_kl': analytical_cond_kl,
                        'elbo_marginal_entropies': elbo_marginal_entropies,
                        'elbo_joint_entropy': elbo_joint_entropy,
                        'metric': metric,
                        'metric_marginal_entropies': metric_marginal_entropies,
                        'metric_cond_entropies': metric_cond_entropies,
                        'metric_subspaces': metric_subspaces,
                        'metric_marginal_entropies_subspaces':
                        metric_marginal_entropies_subspaces,
                        'metric_cond_entropies_subspaces':
                        metric_cond_entropies_subspaces,
                        'p0': p0,
                        'p1': p1list
                    }, os.path.join(args.save, 'combined_data.pth'))

    return vae
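This loop calls anneal_kl(args, vae, iteration), which is not part of the listing. A hedged sketch of a linear warm-up schedule consistent with the --beta-anneal and --lambda-anneal flags; the warm-up length and the lambda start value are arbitrary assumptions.

def anneal_kl(args, vae, iteration, warmup_iter=7000):
    # Linearly ramp beta up to its target and lambda down to zero (assumed schedule).
    if getattr(args, 'beta_anneal', False):
        vae.beta = min(args.beta, args.beta / warmup_iter * iteration)
    else:
        vae.beta = args.beta
    if getattr(args, 'lambda_anneal', False):
        vae.lamb = max(0.0, 0.95 - 0.95 / warmup_iter * iteration)
    else:
        vae.lamb = 0.0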
Example #16
0
def valid_model(_print,
                cfg,
                model,
                valid_loader,
                optimizer,
                epoch,
                cycle=None,
                best_metric=None,
                checkpoint=False):
    tta = cfg.INFER.TTA
    threshold = cfg.INFER.THRESHOLD
    # switch to evaluate mode
    model.eval()
    freeze_batchnorm(model)

    # valid_dice = []
    # valid_iou = []
    valid_output = []
    valid_mask = []
    valid_cls_output = []
    valid_label = []
    tbar = tqdm(valid_loader)

    with torch.no_grad():
        for i, (image, mask, label) in enumerate(tbar):
            image = image.cuda()
            mask = mask.cuda()
            if tta:
                output, cls_output = model(image)
                tta_output, tta_cls_output = model(image.flip(3))
                output = (output + tta_output.flip(3)) / 2.
                cls_output = (cls_output + tta_cls_output) / 2.
            else:
                output, cls_output = model(image)

            cls_output = torch.sigmoid(cls_output)
            valid_cls_output.append(cls_output.cpu().numpy())
            valid_label.append(label.numpy())
            valid_output.append(output.cpu())
            valid_mask.append(mask.cpu())
            # batch_dice = binary_dice_metric(output, mask,
            #     threshold).cpu()
            # valid_dice.append(batch_dice)
            # batch_iou = binary_iou_metric(output, mask,
            #     threshold).cpu()
            # valid_iou.append(batch_iou)

    valid_cls_output = np.concatenate(valid_cls_output, 0)
    valid_label = np.concatenate(valid_label, 0)
    np.save(os.path.join(cfg.DIRS.OUTPUTS, f'{cfg.EXP}_cls.npy'),
            valid_cls_output)
    np.save(os.path.join(cfg.DIRS.OUTPUTS, f'label_{cfg.DATA.FOLD}.npy'),
            valid_label)
    cls_threshold = search_threshold(valid_cls_output, valid_label)
    valid_cls_output = valid_cls_output > np.expand_dims(
        np.array(cls_threshold), 0)
    valid_f1 = [
        f1_score(valid_label[:, i], valid_cls_output[:, i]) for i in range(5)
    ]
    macro_f1 = np.average(valid_f1)

    valid_output = torch.cat(valid_output, 0)
    valid_mask = torch.cat(valid_mask, 0)
    # torch.save(valid_output,
    #     os.path.join(cfg.DIRS.OUTPUTS, f'{cfg.EXP}.pth'))
    torch.save(valid_mask,
               os.path.join(cfg.DIRS.OUTPUTS, f'mask_{cfg.DATA.FOLD}.pth'))
    # valid_dice = torch.cat(valid_dice, 0).mean(0).numpy()
    # valid_iou = torch.cat(valid_iou, 0).mean(0).numpy()
    valid_cls_output = torch.from_numpy(valid_cls_output)
    valid_cls_output_mask = torch.stack([
        valid_cls_output[:, i, ...] > th for i, th in enumerate(cls_threshold)
    ], 1)
    valid_cls_output_mask = valid_cls_output_mask.unsqueeze(-1).unsqueeze(
        -1).float()
    valid_output *= valid_cls_output_mask
    torch.save(valid_output, os.path.join(cfg.DIRS.OUTPUTS, f'{cfg.EXP}.pth'))
    valid_dice = binary_dice_metric(valid_output, valid_mask,
                                    threshold).mean(0).numpy()
    valid_iou = binary_iou_metric(valid_output, valid_mask,
                                  threshold).mean(0).numpy()
    mean_dice = np.average(valid_dice)
    mean_iou = np.average(valid_iou)
    final_score = mean_iou
    log_info = "Mean Dice: %.8f - BE: %.8f - suspicious: %.8f - HGD: %.8f - cancer: %.8f - polyp: %.8f\n"
    log_info += "Mean IoU: %.8f - BE: %.8f - suspicious: %.8f - HGD: %.8f - cancer: %.8f - polyp: %.8f\n"
    log_info += "Macro F1: %.8f - BE: %.8f - suspicious: %.8f - HGD: %.8f - cancer: %.8f - polyp: %.8f"
    _print(log_info %
           (mean_dice, valid_dice[0], valid_dice[1], valid_dice[2],
            valid_dice[3], valid_dice[4], mean_iou, valid_iou[0], valid_iou[1],
            valid_iou[2], valid_iou[3], valid_iou[4], macro_f1, valid_f1[0],
            valid_f1[1], valid_f1[2], valid_f1[3], valid_f1[4]))

    # checkpoint
    if checkpoint:
        is_best = final_score > best_metric
        best_metric = max(final_score, best_metric)
        save_dict = {
            "epoch": epoch + 1,
            "arch": cfg.EXP,
            "state_dict": model.state_dict(),
            "best_metric": best_metric,
            "optimizer": optimizer.state_dict()
        }
        if cycle is not None:
            save_dict["cycle"] = cycle
            save_filename = f"{cfg.EXP}_cycle{cycle}.pth"
        else:
            save_filename = f"{cfg.EXP}.pth"
        save_checkpoint(save_dict,
                        is_best,
                        root=cfg.DIRS.WEIGHTS,
                        filename=save_filename)
        return best_metric
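binary_dice_metric and binary_iou_metric above come from the project's own metric module. A sketch of a per-sample, per-class Dice that matches how the result is averaged with .mean(0); applying a sigmoid to the raw outputs is an assumption.

import torch

def binary_dice_metric(output, target, threshold=0.5, eps=1e-7):
    # Assumes output holds raw logits for C channels; returns an [N, C] tensor.
    preds = (torch.sigmoid(output) > threshold).float()
    target = target.float()
    spatial_dims = tuple(range(2, preds.dim()))
    intersection = (preds * target).sum(dim=spatial_dims)
    cardinality = preds.sum(dim=spatial_dims) + target.sum(dim=spatial_dims)
    return (2.0 * intersection + eps) / (cardinality + eps)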
Example #17
0
    def warmup_generator(self, generator):
        """ Training on L1 Loss to warmup the Generator.

    Minimizing the L1 Loss will reduce the Peak Signal to Noise Ratio (PSNR)
    of the generated image from the generator.
    This trained generator is then used to bootstrap the training of the
    GAN, creating better image inputs instead of random noises.
    Args:
      generator: Model Object for the Generator
    """
        # Loading up phase parameters
        warmup_num_iter = self.settings.get("warmup_num_iter", None)
        phase_args = self.settings["train_psnr"]
        decay_params = phase_args["adam"]["decay"]
        decay_step = decay_params["step"]
        decay_factor = decay_params["factor"]
        total_steps = phase_args["num_steps"]
        metric = tf.keras.metrics.Mean()
        psnr_metric = tf.keras.metrics.Mean()
        # Generator Optimizer
        G_optimizer = tf.optimizers.Adam(
            learning_rate=phase_args["adam"]["initial_lr"],
            beta_1=phase_args["adam"]["beta_1"],
            beta_2=phase_args["adam"]["beta_2"])
        checkpoint = tf.train.Checkpoint(G=generator, G_optimizer=G_optimizer)

        status = utils.load_checkpoint(checkpoint, "phase_1", self.model_dir)
        logging.debug("phase_1 status object: {}".format(status))
        previous_psnr = 0
        start_time = time.time()

        # Training starts

        def _step_fn(image_lr, image_hr):
            logging.debug("Starting Distributed Step")
            with tf.GradientTape() as tape:
                fake = generator.unsigned_call(image_lr)
                loss = utils.pixel_loss(image_hr,
                                        fake) * (1.0 / self.batch_size)
            psnr_metric(
                tf.reduce_mean(tf.image.psnr(fake, image_hr, max_val=256.0)))
            gen_vars = list(set(generator.trainable_variables))
            gradient = tape.gradient(loss, gen_vars)
            G_optimizer.apply_gradients(zip(gradient, gen_vars))
            mean_loss = metric(loss)
            logging.debug("Ending Distributed Step")
            return tf.cast(G_optimizer.iterations, tf.float32)

        @tf.function
        def train_step(image_lr, image_hr):
            distributed_metric = self.strategy.experimental_run_v2(
                _step_fn, args=[image_lr, image_hr])
            mean_metric = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                               distributed_metric,
                                               axis=None)
            return mean_metric

        while True:
            image_lr, image_hr = next(self.dataset)
            num_steps = train_step(image_lr, image_hr)

            if num_steps >= total_steps:
                return
            if status:
                status.assert_consumed()
                logging.info("consumed checkpoint for phase_1 successfully")
                status = None

            if not num_steps % decay_step:  # Decay Learning Rate
                logging.debug("Learning Rate: %s" %
                              G_optimizer.learning_rate.numpy)
                G_optimizer.learning_rate.assign(G_optimizer.learning_rate *
                                                 decay_factor)
                logging.debug("Decayed Learning Rate by %f."
                              "Current Learning Rate %s" %
                              (decay_factor, G_optimizer.learning_rate))
            with self.summary_writer.as_default():
                tf.summary.scalar("warmup_loss",
                                  metric.result(),
                                  step=G_optimizer.iterations)
                tf.summary.scalar("mean_psnr", psnr_metric.result(),
                                  G_optimizer.iterations)

            if not num_steps % self.settings["print_step"]:
                logging.info("[WARMUP] Step: {}\tGenerator Loss: {}"
                             "\tPSNR: {}\tTime Taken: {} sec".format(
                                 num_steps, metric.result(),
                                 psnr_metric.result(),
                                 time.time() - start_time))
                if psnr_metric.result() > previous_psnr:
                    utils.save_checkpoint(checkpoint, "phase_1",
                                          self.model_dir)
                previous_psnr = psnr_metric.result()
                start_time = time.time()
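utils.pixel_loss in the warm-up step is the L1 ("pixel") loss between the high-resolution target and the generated image. A minimal sketch; reducing with a plain mean over all pixels is an assumption (the caller already rescales by the batch size).

import tensorflow as tf

def pixel_loss(y_true, y_pred):
    # Mean absolute error over batch, height, width and channels.
    return tf.reduce_mean(tf.abs(y_true - y_pred))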
Example #18
0
def main():
    args = get_config()

    # prepare path
    utils.clear_dir(args.tmp_path)
    utils.clear_dir(args.save_path)
    print('[*] Clear path')

    # choose device
    if args.gpu >= 0:
        device = torch.device('cuda:%d' % args.gpu)
        print('[*] Choose cuda:%d as device' % args.gpu)
    else:
        device = torch.device('cpu')
        print('[*] Choose cpu as device')

    # load training set and evaluation set
    train_set = PolyMRDataset(size=args.image_size, type=args.dataset)
    train_loader = DataLoader(dataset=train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=0)
    valid_set = PolyMRDataset(size=args.image_size,
                              type=args.dataset,
                              set='val')
    valid_loader = DataLoader(dataset=valid_set,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=0)
    print('[*] Load datasets')
    img_size = train_set.size
    dataset_size_train = len(train_set)
    dataset_size_valid = len(valid_set)

    # setup the model
    vae = ReasonNet(args, device)
    vae = vae.to(device)
    print('[*] Load model')

    # setup the optimizer
    optimizer = optim.Adam(vae.parameters(), lr=args.learning_rate)

    # training loop
    iteration = 0
    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    # record best elbo and epoch
    best_elbo = -1e6
    best_epoch = -1
    for epoch in range(args.num_epochs):
        vae.train()
        for i, (x, t) in enumerate(train_loader):
            optimizer.zero_grad()
            x = x.view(-1, 1, img_size, img_size).to(device)
            t = t.view(-1, 1, img_size, img_size).to(device)
            obj, recon = vae(x, t, dataset_size_train)
            obj.mul(-1).backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), 10.0)
            optimizer.step()
            iteration += 1

        vae.eval()
        elbo_running_mean.reset()
        with torch.no_grad():
            for i, (x, t) in enumerate(valid_loader):
                x = x.view(-1, 1, img_size, img_size).to(device)
                t = t.view(-1, 1, img_size, img_size).to(device)
                obj, elbo, recon = vae.evaluate(x, t, dataset_size_valid)
                elbo_running_mean.update(elbo)

            avg_elbo = elbo_running_mean.get_avg()['elbo']
            if avg_elbo > best_elbo:
                best_epoch = epoch
                utils.save_checkpoint(
                    {
                        'state_dict': vae.state_dict(),
                        'args': args,
                        'epoch': epoch
                    }, args.save_path)
                best_elbo = avg_elbo

            if epoch % args.log_freq == 0:
                elbo_running_mean.log(epoch, args.num_epochs, best_epoch)
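utils.clear_dir above is assumed to empty (or create) a working directory before the run. A simple sketch; deleting the whole tree and recreating it empty is an assumption about the intended behaviour.

import os
import shutil

def clear_dir(path):
    # Remove the directory tree if it exists, then recreate it empty.
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)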
Example #19
0
def main():

    global args, best_prec1
    args = parser.parse_args()

    my_whole_seed = 111
    random.seed(my_whole_seed)
    np.random.seed(my_whole_seed)
    torch.manual_seed(my_whole_seed)
    torch.cuda.manual_seed_all(my_whole_seed)
    torch.cuda.manual_seed(my_whole_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(my_whole_seed)

    for kk_time in range(args.seedstart, args.seedend):
        args.seed = kk_time
        args.result = args.result + str(args.seed)

        # create model
        model = models.__dict__[args.arch](low_dim=args.low_dim,
                                           multitask=args.multitask,
                                           showfeature=args.showfeature,
                                           args=args)
        #
        # from models.Gresnet import ResNet18
        # model = ResNet18(low_dim=args.low_dim, multitask=args.multitask)
        model = torch.nn.DataParallel(model).cuda()

        # Data loading code
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        aug = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ])
        # aug = transforms.Compose([transforms.RandomResizedCrop(224, scale=(0.08, 1.), ratio=(3 / 4, 4 / 3)),
        #                           transforms.RandomHorizontalFlip(p=0.5),
        #                           get_color_distortion(s=1),
        #                           transforms.Lambda(lambda x: gaussian_blur(x)),
        #                           transforms.ToTensor(),
        #                           normalize])
        # aug = transforms.Compose([transforms.RandomRotation(60),
        #                           transforms.RandomResizedCrop(224, scale=(0.6, 1.)),
        #                           transforms.RandomGrayscale(p=0.2),
        #                           transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        #                           transforms.RandomHorizontalFlip(),
        #                           transforms.ToTensor(),
        #                             normalize])
        aug_test = transforms.Compose(
            [transforms.Resize(224),
             transforms.ToTensor(), normalize])

        # dataset
        import datasets.fundus_kaggle_dr as medicaldata
        train_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug,
                                                 train=True,
                                                 args=args)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=8,
            drop_last=True if args.multiaug else False,
            worker_init_fn=random.seed(my_whole_seed))

        valid_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug_test,
                                                 train=False,
                                                 test_type="amd",
                                                 args=args)
        val_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=8,
            worker_init_fn=random.seed(my_whole_seed))
        valid_dataset_gon = medicaldata.traindataset(root=args.data,
                                                     transform=aug_test,
                                                     train=False,
                                                     test_type="gon",
                                                     args=args)
        val_loader_gon = torch.utils.data.DataLoader(
            valid_dataset_gon,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=8,
            worker_init_fn=random.seed(my_whole_seed))
        valid_dataset_pm = medicaldata.traindataset(root=args.data,
                                                    transform=aug_test,
                                                    train=False,
                                                    test_type="pm",
                                                    args=args)
        val_loader_pm = torch.utils.data.DataLoader(
            valid_dataset_pm,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=8,
            worker_init_fn=random.seed(my_whole_seed))

        # define lemniscate and loss function (criterion)
        ndata = train_dataset.__len__()

        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                                   args.nce_m).cuda()
        local_lemniscate = None

        if args.multitaskposrot:
            print("running multi task with positive")
            criterion = BatchCriterionRot(1, 0.1, args.batch_size, args).cuda()
        elif args.domain:
            print("running domain with four types--unify ")
            from lib.BatchAverageFour import BatchCriterionFour
            # criterion = BatchCriterionTriple(1, 0.1, args.batch_size, args).cuda()
            criterion = BatchCriterionFour(1, 0.1, args.batch_size,
                                           args).cuda()
        elif args.multiaug:
            print("running multi task")
            criterion = BatchCriterion(1, 0.1, args.batch_size, args).cuda()
        else:
            criterion = nn.CrossEntropyLoss().cuda()

        if args.multitask:
            cls_criterion = nn.CrossEntropyLoss().cuda()
        else:
            cls_criterion = None

        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     weight_decay=args.weight_decay)

        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                lemniscate = checkpoint['lemniscate']
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        if args.evaluate:
            knn_num = 100
            auc, acc, precision, recall, f1score = kNN(args, model, lemniscate,
                                                       train_loader,
                                                       val_loader, knn_num,
                                                       args.nce_t, 2)
            return

        # mkdir result folder and tensorboard
        os.makedirs(args.result, exist_ok=True)
        writer = SummaryWriter("runs/" + str(args.result.split("/")[-1]))
        writer.add_text('Text', str(args))

        # copy code
        import shutil, glob
        source = glob.glob("*.py")
        source += glob.glob("*/*.py")
        os.makedirs(args.result + "/code_file", exist_ok=True)
        for file in source:
            name = file.split("/")[0]
            if name == file:
                shutil.copy(file, args.result + "/code_file/")
            else:
                os.makedirs(args.result + "/code_file/" + name, exist_ok=True)
                shutil.copy(file, args.result + "/code_file/" + name)

        for epoch in range(args.start_epoch, args.epochs):
            lr = adjust_learning_rate(optimizer, epoch, args, [100, 200])
            writer.add_scalar("lr", lr, epoch)

            # # train for one epoch
            loss = train(train_loader, model, lemniscate, local_lemniscate,
                         criterion, cls_criterion, optimizer, epoch, writer)
            writer.add_scalar("train_loss", loss, epoch)

            # gap_int = 10
            # if (epoch) % gap_int == 0:
            #     knn_num = 100
            #     auc, acc, precision, recall, f1score = kNN(args, model, lemniscate, train_loader, val_loader, knn_num, args.nce_t, 2)
            #     writer.add_scalar("test_auc", auc, epoch)
            #     writer.add_scalar("test_acc", acc, epoch)
            #     writer.add_scalar("test_precision", precision, epoch)
            #     writer.add_scalar("test_recall", recall, epoch)
            #     writer.add_scalar("test_f1score", f1score, epoch)
            #
            #     auc, acc, precision, recall, f1score = kNN(args, model, lemniscate, train_loader, val_loader_gon,
            #                                                knn_num, args.nce_t, 2)
            #     writer.add_scalar("gon/test_auc", auc, epoch)
            #     writer.add_scalar("gon/test_acc", acc, epoch)
            #     writer.add_scalar("gon/test_precision", precision, epoch)
            #     writer.add_scalar("gon/test_recall", recall, epoch)
            #     writer.add_scalar("gon/test_f1score", f1score, epoch)
            #     auc, acc, precision, recall, f1score = kNN(args, model, lemniscate, train_loader, val_loader_pm,
            #                                                knn_num, args.nce_t, 2)
            #     writer.add_scalar("pm/test_auc", auc, epoch)
            #     writer.add_scalar("pm/test_acc", acc, epoch)
            #     writer.add_scalar("pm/test_precision", precision, epoch)
            #     writer.add_scalar("pm/test_recall", recall, epoch)
            #     writer.add_scalar("pm/test_f1score", f1score, epoch)

            # save checkpoint
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'lemniscate': lemniscate,
                    'optimizer': optimizer.state_dict(),
                },
                filename=args.result + "/fold" + str(args.seedstart) +
                "-epoch-" + str(epoch) + ".pth.tar")
Example #20
0
    data, target, meta = next(iter(train_loader))
    step_loss, step_precision = train_triplet_step(data, target, model, device,
                                                   optimizer, miner)

    print('Train Step: {} Precision@1: {:.4f}\tLoss: {:.6f}'.format(
        step, step_precision, step_loss),
          flush=True)

    if step % args.val_freq == 0:
        total_loss, acc_dict, embedding_list, target_list = representation(
            model, device, validation_loader)
        lr_scheduler.step(total_loss)
        es(total_loss, step, model.state_dict(), output_dir / 'model.pt')

        save_checkpoint(
            model, optimizer, lr_scheduler,
            train_loader.sampler.state_dict(train_loader._infinite_iterator),
            step + 1, es, torch.random.get_rng_state())

_, acc_dict, embedding_list, target_list = representation(
    model, device, test_loader)
_, acc_dict_aug, embedding_list_aug, target_list_aug = representation(
    model, device, test_loader_aug)

results = {}
acc_calc = AccuracyCalculator()
for m, embedding, target in zip(['unaug', 'aug'],
                                [embedding_list, embedding_list_aug],
                                [target_list, target_list_aug]):
    results[m] = {}
    for grp in np.unique(target):
        target_bin = target == grp
def main():
    global best_test_bpd

    last_checkpoints = []
    lipschitz_constants = []
    ords = []
    alphas = []
    betas = []
    concat_eta1 = []
    concat_eta2 = []
    concat_K1 = []
    concat_K2 = []

    # if args.resume:
    #     validate(args.begin_epoch - 1, model, ema)
    for epoch in range(args.begin_epoch, args.nepochs):

        logger.info('Current LR {}'.format(optimizer.param_groups[0]['lr']))

        train(epoch, model)
        lipschitz_constants.append(get_lipschitz_constants(model))
        logger.info('Lipsh: {}'.format(pretty_repr(lipschitz_constants[-1])))

        if args.learn_p:
            ords.append(get_ords(model))
            logger.info('Order: {}'.format(pretty_repr(ords[-1])))

        if args.act == 'LeakyLSwish':
            alpha, beta = get_activation_params(model)
            alphas.append(alpha)
            betas.append(beta)

            logger.info('alphas: {}'.format(pretty_repr(alphas[-1])))
            logger.info('betas: {}'.format(pretty_repr(betas[-1])))

        if args.learnable_concat:
            eta1, eta2, K1, K2 = get_learnable_params(model)
            concat_eta1.append(eta1)
            concat_eta2.append(eta2)
            concat_K1.append(K1)
            concat_K2.append(K2)

            logger.info('eta1: {}'.format(pretty_repr(concat_eta1[-1])))
            logger.info('eta2: {}'.format(pretty_repr(concat_eta2[-1])))
            logger.info('K1: {}'.format(pretty_repr(concat_K1[-1])))
            logger.info('K2: {}'.format(pretty_repr(concat_K2[-1])))

        if args.ema_val:
            test_bpd = validate(epoch, model, ema)
        else:
            test_bpd = validate(epoch, model)

        if args.scheduler and scheduler is not None:
            scheduler.step()

        if test_bpd < best_test_bpd:
            best_test_bpd = test_bpd
            utils.save_checkpoint(
                {
                    'state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'args': args,
                    'ema': ema,
                    'test_bpd': test_bpd,
                },
                os.path.join(args.save, 'models'),
                epoch,
                last_checkpoints,
                num_checkpoints=5)

        torch.save(
            {
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'args': args,
                'ema': ema,
                'test_bpd': test_bpd,
            }, os.path.join(args.save, 'models', 'most_recent.pth'))
            print('')

            utils.makedirs(os.path.dirname(fig_filename))
            plt.savefig(fig_filename)

            #plt.ion()
            #plt.show()

            #plt.pause(0.1)
            plt.close()

            #utils.save_checkpoint({'state_dict': genGen.state_dict()}, os.path.join(args.save, 'myModels3'), itr)
            #utils.save_checkpoint({'state_dict': genGen.state_dict()}, os.path.join(args.save, 'myModels'), args.niters2)

            utils.save_checkpoint({'state_dict': genGen.state_dict()},
                                  os.path.join(args.save, 'myModels'),
                                  args.niters2)

        #loss_meter.update(loss.item())
        #logpz_meter.update(logpz.item())

        loss2_meter.update(lossGen.item())

        #delta_logp_meter.update(delta_logp.item())

        #loss.backward()
        #lossGen.backward()

        #lossGen.backward(create_graph=True)
        lossGen.backward()
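
utils.save_checkpoint in the flow-training snippets above is called both as save_checkpoint(state, save_dir, epoch) and with the extra last_checkpoints / num_checkpoints arguments, which suggests a helper that optionally keeps only a bounded history of files. A sketch under that assumption (the file-naming scheme is a guess):

import os
import torch

def save_checkpoint(state, save_dir, epoch, last_checkpoints=None, num_checkpoints=None):
    # Write this epoch's checkpoint; optionally prune the oldest files so that
    # at most num_checkpoints of them are kept on disk.
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, "checkpt-{:04d}.pth".format(epoch))
    torch.save(state, path)
    if last_checkpoints is not None and num_checkpoints is not None:
        last_checkpoints.append(path)
        while len(last_checkpoints) > num_checkpoints:
            stale = last_checkpoints.pop(0)
            if os.path.exists(stale):
                os.remove(stale)
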
Example #23
0
def main():
    torch.autograd.set_detect_anomaly(True)
    logger.info('Start to declare training variable')
    if torch.cuda.is_available():
        cfg.device = torch.device("cuda")
        torch.cuda.set_device(cfg.local_rank)
    else:
        cfg.device = torch.device("cpu")
    logger.info('Session will run on device: [%s]' % cfg.device)
    start_epoch = 0
    best_acc = 0.

    logger.info('Start to prepare data')
    # get transformers
    # train_transform is for data perturbation
    train_transform = transforms.get(train=True)
    # test_transform is for evaluation
    test_transform = transforms.get(train=False)
    # reduced_transform is for original training data
    reduced_transform = get_reduced_transform(cfg.tfm_resize, cfg.tfm_size,
                                              cfg.tfm_blur, cfg.tfm_means,
                                              cfg.tfm_stds,
                                              cfg.tfm_adaptive_thresholding)
    # get datasets
    # each head should have its own trainset
    train_splits = dict(cifar100=[['train', 'test']],
                        image_folder_wrapper=[['train']],
                        stl10=[['train+unlabeled', 'test'], ['train', 'test']])
    test_splits = dict(cifar100=['train', 'test'],
                       image_folder_wrapper=['test'],
                       stl10=['train', 'test'])
    # instance dataset for each head
    # otrainset: original trainset
    otrainset = [
        ConcatDataset([
            datasets.get(split=split, transform=reduced_transform)
            for split in train_splits[cfg.dataset][hidx]
        ]) for hidx in range(len(train_splits[cfg.dataset]))
    ]
    # ptrainset: perturbed trainset
    ptrainset = [
        ConcatDataset([
            datasets.get(split=split, transform=train_transform)
            for split in train_splits[cfg.dataset][hidx]
        ]) for hidx in range(len(train_splits[cfg.dataset]))
    ]
    # testset
    testset = ConcatDataset([
        datasets.get(split=split, transform=test_transform)
        for split in test_splits[cfg.dataset]
    ])
    # declare data loaders for testset only
    test_loader = DataLoader(testset,
                             batch_size=cfg.batch_size,
                             shuffle=False,
                             num_workers=cfg.num_workers)

    logger.info('Start to build model')
    net = networks.get()
    criterion = PUILoss(cfg.pica_lamda, cfg.pica_target, cfg.pica_iic)
    optimizer = optimizers.get(
        params=[val for _, val in net.trainable_parameters().items()])
    lr_handler = lr_policy.get()

    # load session if checkpoint is provided
    if cfg.resume:
        assert os.path.exists(cfg.resume), "Resume file not found"
        ckpt = torch.load(cfg.resume)
        logger.info('Start to resume session for file: [%s]' % cfg.resume)
        net.load_state_dict(ckpt['net'])
        best_acc = ckpt['acc']
        start_epoch = ckpt['epoch']

    # move modules to target device
    if int(os.environ["WORLD_SIZE"]) > 1:
        dist.init_process_group(backend="nccl", init_method="env://")
    print("world size: {}".format(os.environ["WORLD_SIZE"]))
    print("rank: {}".format(cfg.local_rank))
    synchronize()

    criterion = criterion.to(cfg.device)
    net = net.to(cfg.device)

    if int(os.environ["WORLD_SIZE"]) > 1:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
        net = torch.nn.parallel.DistributedDataParallel(
            net,
            device_ids=[cfg.local_rank],
            find_unused_parameters=True,
            output_device=cfg.local_rank).cuda()

    # Only rank 0 needs a SummaryWriter
    if cfg.local_rank == 0:
        # tensorboard writer
        writer = SummaryWriter(cfg.debug, log_dir=cfg.tfb_dir)
    else:
        writer = None

    # start training
    lr = cfg.base_lr
    epoch = start_epoch

    logger.info('Start to evaluate after %d epoch of training' % epoch)
    acc = evaluate(net, test_loader, writer, epoch)

    if not cfg.debug and cfg.local_rank == 0:
        # save checkpoint
        is_best = acc > best_acc
        best_acc = max(best_acc, acc)
        save_checkpoint(
            {
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'acc': acc,
                'epoch': epoch
            },
            is_best=is_best)

    while lr > 0 and epoch < cfg.max_epochs:

        lr = lr_handler.update(epoch, optimizer)

        logger.info('Start to train at %d epoch with learning rate %.5f' %
                    (epoch, lr))
        train(epoch, net, otrainset, ptrainset, optimizer, criterion, writer)

        epoch += 1

        logger.info('Start to evaluate after %d epoch of training' % epoch)
        acc = evaluate(net, test_loader, writer, epoch)

        if not cfg.debug and cfg.local_rank == 0:
            writer.add_scalar('Train/Learning_Rate', lr, epoch)
            # save checkpoint
            is_best = acc > best_acc
            best_acc = max(best_acc, acc)
            save_checkpoint(
                {
                    'net': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'acc': acc,
                    'epoch': epoch
                },
                is_best=is_best)

    logger.info('Done')
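
The save_checkpoint(state, is_best=is_best) pattern used above conventionally writes the latest state and keeps an extra copy of the best model seen so far. A minimal sketch under that assumption (directory and file names are hypothetical):

import os
import shutil
import torch

def save_checkpoint(state, is_best=False, ckpt_dir="checkpoints",
                    filename="checkpoint.pth.tar"):
    # Always store the most recent state; copy it aside when it is the best so far.
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(ckpt_dir, "model_best.pth.tar"))
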
Example #24
0
def main():
    # parse command line arguments
    parser = argparse.ArgumentParser(description="parse args")
    parser.add_argument('-d', '--dataset', default='shapes', type=str, help='dataset name',
        choices=['shapes', 'faces'])
    parser.add_argument('-dist', default='normal', type=str, choices=['normal', 'laplace', 'flow'])
    parser.add_argument('-n', '--num-epochs', default=50, type=int, help='number of training epochs')
    parser.add_argument('-b', '--batch-size', default=2048, type=int, help='batch size')
    parser.add_argument('-l', '--learning-rate', default=1e-3, type=float, help='learning rate')
    parser.add_argument('-z', '--latent-dim', default=10, type=int, help='size of latent dimension')
    parser.add_argument('--beta', default=1, type=float, help='ELBO penalty term')
    parser.add_argument('--tcvae', action='store_true')
    parser.add_argument('--exclude-mutinfo', action='store_true')
    parser.add_argument('--beta-anneal', action='store_true')
    parser.add_argument('--lambda-anneal', action='store_true')
    parser.add_argument('--mss', action='store_true', help='use the improved minibatch estimator')
    parser.add_argument('--conv', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--visdom', action='store_true', help='whether plotting in visdom is desired')
    parser.add_argument('--save', default='test1')
    parser.add_argument('--log_freq', default=200, type=int, help='num iterations per log')
    args = parser.parse_args()

    # torch.cuda.set_device(args.gpu)

    # data loader
    train_loader = setup_data_loaders(args, use_cuda=True)

    # setup the VAE
    if args.dist == 'normal':
        prior_dist = dist.Normal()
        q_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
        q_dist = dist.Laplace()
    elif args.dist == 'flow':
        prior_dist = FactorialNormalizingFlow(dim=args.latent_dim, nsteps=32)
        q_dist = dist.Normal()

    vae = VAE(z_dim=args.latent_dim, use_cuda=True, prior_dist=prior_dist, q_dist=q_dist,
        include_mutinfo=not args.exclude_mutinfo, tcvae=args.tcvae, conv=args.conv, mss=args.mss)

    # setup the optimizer
    optimizer = optim.Adam(vae.parameters(), lr=args.learning_rate)

    # setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom(env=args.save, port=4500)

    train_elbo = []

    # training loop
    dataset_size = len(train_loader.dataset)
    num_iterations = len(train_loader) * args.num_epochs
    iteration = 0
    # initialize loss accumulator
    elbo_running_mean = utils.RunningAverageMeter()
    while iteration < num_iterations:
        for i, x in enumerate(train_loader):
            iteration += 1
            batch_time = time.time()
            vae.train()
            anneal_kl(args, vae, iteration)
            optimizer.zero_grad()
            # transfer to GPU ('async' became a reserved keyword in Python 3.7+)
            x = x.cuda(non_blocking=True)
            # wrap the mini-batch in a PyTorch Variable (a no-op in PyTorch >= 0.4)
            x = Variable(x)
            # do ELBO gradient and accumulate loss
            obj, elbo = vae.elbo(x, dataset_size)
            if utils.isnan(obj).any():
                raise ValueError('NaN spotted in objective.')
            obj.mean().mul(-1).backward()
            print("obj value: ", obj.mean().mul(-1).cpu())
            elbo_running_mean.update(elbo.mean().item())
            optimizer.step()

            # report training diagnostics
            if iteration % args.log_freq == 0:
                train_elbo.append(elbo_running_mean.avg)
                print('[iteration %03d] time: %.2f \tbeta %.2f \tlambda %.2f training ELBO: %.4f (%.4f)' % (
                    iteration, time.time() - batch_time, vae.beta, vae.lamb,
                    elbo_running_mean.val, elbo_running_mean.avg))

                vae.eval()

                # plot training and test ELBOs
                if args.visdom:
                    display_samples(vae, x, vis)
                    plot_elbo(train_elbo, vis)

                utils.save_checkpoint({
                    'state_dict': vae.state_dict(),
                    'args': args}, args.save, 0)
                eval('plot_vs_gt_' + args.dataset)(vae, train_loader.dataset,
                    os.path.join(args.save, 'gt_vs_latent_{:05d}.png'.format(iteration)))

    # Report statistics after training
    vae.eval()
    utils.save_checkpoint({
        'state_dict': vae.state_dict(),
        'args': args}, args.save, 0)
    dataset_loader = DataLoader(train_loader.dataset, batch_size=10, num_workers=1, shuffle=False)
    logpx, dependence, information, dimwise_kl, analytical_cond_kl, marginal_entropies, joint_entropy = \
        elbo_decomposition(vae, dataset_loader)
    torch.save({
        'logpx': logpx,
        'dependence': dependence,
        'information': information,
        'dimwise_kl': dimwise_kl,
        'analytical_cond_kl': analytical_cond_kl,
        'marginal_entropies': marginal_entropies,
        'joint_entropy': joint_entropy
    }, os.path.join(args.save, 'elbo_decomposition.pth'))
    eval('plot_vs_gt_' + args.dataset)(vae, dataset_loader.dataset, os.path.join(args.save, 'gt_vs_latent.png'))
    return vae
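
utils.RunningAverageMeter used above is not defined on this page; a common implementation tracks an exponential moving average of a scalar, which matches how .val and .avg are read in the logging line. A sketch under that assumption (the momentum value is a guess):

class RunningAverageMeter(object):
    # Exponential moving average of a scalar (here: the training ELBO).
    def __init__(self, momentum=0.97):
        self.momentum = momentum
        self.val = None
        self.avg = 0.0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1.0 - self.momentum)
        self.val = val
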
Example #25
0
    def train(self,
              nets,
              criterions,
              optimizers,
              train_loader,
              test_loader,
              logs=None,
              **kwargs):
        import time
        import os

        print("manual seed : %d" % self.args.manualSeed)

        for epoch in range(self.args.trainer.start_epoch,
                           self.args.trainer.epochs + 1):
            print("epoch %d" % epoch)
            start_time = time.time()

            for optimizer, model_args in zip(optimizers, self.args.models):
                utils.adjust_learning_rate(optimizer, epoch,
                                           model_args.optim.gammas,
                                           model_args.optim.schedule,
                                           model_args.optim.args.lr)

            kwargs = {} if kwargs is None else kwargs
            kwargs.update({
                "_trainer": self,
                "_train_loader": train_loader,
                "_test_loader": test_loader,
                "_nets": nets,
                "_criterions": criterions,
                "_optimizers": optimizers,
                "_epoch": epoch,
                "_logs": logs,
                "_args": self.args
            })

            # train for one epoch
            self.train_on_dataset(train_loader, nets, criterions, optimizers,
                                  epoch, logs, **kwargs)
            # evaluate on validation set
            self.validate_on_dataset(test_loader, nets, criterions, epoch,
                                     logs, **kwargs)

            # print log
            for i, log in enumerate(logs.net):
                print(
                    "  net{0}    loss :train={1:.3f}, test={2:.3f}    acc :train={3:.3f}, test ={4:.3f}"
                    .format(i, log["epoch_log"][epoch]["train_loss"],
                            log["epoch_log"][epoch]["test_loss"],
                            log["epoch_log"][epoch]["train_accuracy"],
                            log["epoch_log"][epoch]["test_accuracy"]))

            if epoch % self.args.trainer.saving_interval == 0:
                ckpt_dir = os.path.join(self.args.trainer.base_dir,
                                        "checkpoint")
                utils.save_checkpoint(nets, optimizers, epoch, ckpt_dir)

            logs.save(self.args.trainer.base_dir + r"log/")

            elapsed_time = time.time() - start_time
            print("  elapsed_time:{0:.3f}[sec]".format(elapsed_time))

            if "_callback" in kwargs:
                kwargs["_callback"](**kwargs)

        return
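
utils.save_checkpoint(nets, optimizers, epoch, ckpt_dir) in this trainer receives lists of networks and optimizers rather than a single state dict. A sketch of a saver matching that call, assuming it bundles every state dict into one file per epoch (the file name is an assumption):

import os
import torch

def save_checkpoint(nets, optimizers, epoch, ckpt_dir):
    # One file per epoch holding the state of every network/optimizer pair.
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(
        {
            "epoch": epoch,
            "nets": [net.state_dict() for net in nets],
            "optimizers": [opt.state_dict() for opt in optimizers],
        }, os.path.join(ckpt_dir, "checkpoint_epoch_{}.pth.tar".format(epoch)))
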
Example #26
0
def main():

    global args, best_prec1
    args = parser.parse_args()

    #  init seed
    my_whole_seed = 222
    random.seed(my_whole_seed)
    np.random.seed(my_whole_seed)
    torch.manual_seed(my_whole_seed)
    torch.cuda.manual_seed_all(my_whole_seed)
    torch.cuda.manual_seed(my_whole_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(my_whole_seed)

    for kk_time in range(args.seedstart, args.seedstart + 1):
        args.seed = kk_time
        args.result = args.result + str(args.seed)

        # create model
        model = models.__dict__[args.arch](low_dim=args.low_dim,
                                           multitask=args.multitask,
                                           showfeature=args.showfeature,
                                           domain=args.domain,
                                           args=args)
        model = torch.nn.DataParallel(model).cuda()
        print('Number of learnable params',
              get_learnable_para(model) / 1000000., " M")

        # Data loading code
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        aug = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ])
        # aug = transforms.Compose([transforms.RandomResizedCrop(224, scale=(0.08, 1.), ratio=(3 / 4, 4 / 3)),
        #                           transforms.RandomHorizontalFlip(p=0.5),
        #                           get_color_distortion(s=1),
        #                           transforms.Lambda(lambda x: gaussian_blur(x)),
        #                           transforms.ToTensor(),
        #                           normalize])
        aug_test = transforms.Compose(
            [transforms.Resize((224, 224)),
             transforms.ToTensor(), normalize])

        # load dataset
        # import datasets.fundus_amd_syn_crossvalidation as medicaldata
        import datasets.fundus_amd_syn_crossvalidation_ind as medicaldata
        train_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug,
                                                 train=True,
                                                 args=args)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=4,
            drop_last=True if args.multiaug else False,
            # random.seed(...) would run once here and pass None; seed each worker instead
            worker_init_fn=lambda worker_id: random.seed(my_whole_seed + worker_id))

        valid_dataset = medicaldata.traindataset(root=args.data,
                                                 transform=aug_test,
                                                 train=False,
                                                 args=args)
        val_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=4,
            worker_init_fn=lambda worker_id: random.seed(my_whole_seed + worker_id))

        # define lemniscate and loss function (criterion)
        ndata = train_dataset.__len__()

        lemniscate = LinearAverage(args.low_dim, ndata, args.nce_t,
                                   args.nce_m).cuda()

        if args.multitaskposrot:
            cls_criterion = nn.CrossEntropyLoss().cuda()
        else:
            cls_criterion = None

        if args.multitaskposrot:
            print("running multi task with miccai")
            criterion = BatchCriterion(1, 0.1, args.batch_size, args).cuda()
        elif args.synthesis:
            print("running synthesis")
            criterion = BatchCriterionFour(1, 0.1, args.batch_size,
                                           args).cuda()
        elif args.multiaug:
            print("running cvpr")
            criterion = BatchCriterion(1, 0.1, args.batch_size, args).cuda()
        else:
            criterion = nn.CrossEntropyLoss().cuda()

        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     weight_decay=args.weight_decay)

        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                lemniscate = checkpoint['lemniscate']
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        if args.evaluate:
            knn_num = 100
            auc, acc, precision, recall, f1score = kNN(args, model, lemniscate,
                                                       train_loader,
                                                       val_loader, knn_num,
                                                       args.nce_t, 2)
            f = open("savemodels/result.txt", "a+")
            f.write("auc: %.4f\n" % (auc))
            f.write("acc: %.4f\n" % (acc))
            f.write("pre: %.4f\n" % (precision))
            f.write("recall: %.4f\n" % (recall))
            f.write("f1score: %.4f\n" % (f1score))
            f.close()
            return

        # mkdir result folder and tensorboard
        os.makedirs(args.result, exist_ok=True)
        writer = SummaryWriter("runs/" + str(args.result.split("/")[-1]))
        writer.add_text('Text', str(args))

        # copy code
        import shutil, glob
        source = glob.glob("*.py")
        source += glob.glob("*/*.py")
        os.makedirs(args.result + "/code_file", exist_ok=True)
        for file in source:
            name = file.split("/")[0]
            if name == file:
                shutil.copy(file, args.result + "/code_file/")
            else:
                os.makedirs(args.result + "/code_file/" + name, exist_ok=True)
                shutil.copy(file, args.result + "/code_file/" + name)

        for epoch in range(args.start_epoch, args.epochs):
            lr = adjust_learning_rate(optimizer, epoch, args, [1000, 2000])
            writer.add_scalar("lr", lr, epoch)

            # # train for one epoch
            loss = train(train_loader, model, lemniscate, criterion,
                         cls_criterion, optimizer, epoch, writer)
            writer.add_scalar("train_loss", loss, epoch)

            # save checkpoint
            if epoch % 200 == 0 or (epoch in [1600, 1800, 2000]):
                auc, acc, precision, recall, f1score = kNN(
                    args, model, lemniscate, train_loader, val_loader, 100,
                    args.nce_t, 2)
                # save to txt
                writer.add_scalar("test_auc", auc, epoch)
                writer.add_scalar("test_acc", acc, epoch)
                writer.add_scalar("test_precision", precision, epoch)
                writer.add_scalar("test_recall", recall, epoch)
                writer.add_scalar("test_f1score", f1score, epoch)
                f = open(args.result + "/result.txt", "a+")
                f.write("epoch " + str(epoch) + "\n")
                f.write("auc: %.4f\n" % (auc))
                f.write("acc: %.4f\n" % (acc))
                f.write("pre: %.4f\n" % (precision))
                f.write("recall: %.4f\n" % (recall))
                f.write("f1score: %.4f\n" % (f1score))
                f.close()
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'lemniscate': lemniscate,
                        'optimizer': optimizer.state_dict(),
                    },
                    filename=args.result + "/fold" + str(args.seedstart) +
                    "-epoch-" + str(epoch) + ".pth.tar")
Example #27
0
        if gen_plot is not None:
            plot_model_task(model, gen_plot, epoch, wd)

        update_learning_rate(opt,
                             decay_rate=0.999,
                             lowest=args.learning_rate / 10)

        # Update the best objective value and checkpoint the model.
        is_best = False
        if val_obj['nll'] < best_obj:
            best_obj = val_obj['nll']
            is_best = True
        save_checkpoint(wd, {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_acc_top1': best_obj,
            'optimizer': opt.state_dict()
        },
                        is_best=is_best)

else:
    # Load saved model.
    load_dict = torch.load(wd.file('model_best.pth.tar', exists=True))
    model.load_state_dict(load_dict['state_dict'])

# Perform final quality validation
means = {}
for _ in range(10):
    loss_dict = validate(gen_test, model, losses, mode='mean')
    for name in loss_dict:
        if name not in means:
Example #28
0
def main(cfg, config_name):
    """
    Main training function: after preparing the data loaders, model, optimizer, and trainer,
    start with the training process.

    Args:
        cfg (dict): current configuration parameters
        config_name (str): path to the config file
    """

    # Create the output dir if it does not exist
    if not os.path.exists(cfg['misc']['log_dir']):
        os.makedirs(cfg['misc']['log_dir'])

    # Initialize the model
    model = config.get_model(cfg)
    model = model.cuda()

    # Get data loader
    train_loader = make_data_loader(cfg, phase='train')
    val_loader = make_data_loader(cfg, phase='val')

    # Log directory
    dataset_name = cfg["data"]["dataset"]

    now = datetime.now().strftime("%y_%m_%d-%H_%M_%S_%f")
    now += "__Method_" + str(cfg['method']['backbone'])
    now += "__Pretrained_" if cfg['network']['use_pretrained'] and cfg[
        'network']['pretrained_path'] else ''
    if cfg['method']['flow']: now += "__Flow_"
    if cfg['method']['ego_motion']: now += "__Ego_"
    if cfg['method']['semantic']: now += "__Sem_"
    now += "__Rem_Ground_" if cfg['data']['remove_ground'] else ''
    now += "__VoxSize_" + str(cfg['misc']["voxel_size"])
    now += "__Pts_" + str(cfg['misc']["num_points"])
    path2log = os.path.join(cfg['misc']['log_dir'], "logs_" + dataset_name,
                            now)

    logger, checkpoint_dir = prepare_logger(cfg, path2log)
    tboard_logger = SummaryWriter(path2log)

    # Output number of model parameters
    logger.info("Parameter Count: {:d}".format(n_model_parameters(model)))

    # Output torch and cuda version
    logger.info('Torch version: {}'.format(torch.__version__))
    logger.info('CUDA version: {}'.format(torch.version.cuda))

    # Save config file that was used for this experiment
    with open(os.path.join(path2log,
                           config_name.split(os.sep)[-1]), 'w') as outfile:
        yaml.dump(cfg, outfile, default_flow_style=False, allow_unicode=True)

    # Get optimizer and trainer
    optimizer = config.get_optimizer(cfg, model)
    scheduler = config.get_scheduler(cfg, optimizer)

    # Saving and validation intervals (positive values count iterations, negative values count epochs)
    stat_interval = cfg['train']['stat_interval']
    stat_interval = stat_interval if stat_interval > 0 else abs(
        stat_interval * len(train_loader))

    chkpt_interval = cfg['train']['chkpt_interval']
    chkpt_interval = chkpt_interval if chkpt_interval > 0 else abs(
        chkpt_interval * len(train_loader))

    val_interval = cfg['train']['val_interval']
    val_interval = val_interval if val_interval > 0 else abs(val_interval *
                                                             len(train_loader))

    # if no pretrained model is loaded, the epoch and iteration counters start at -1
    metric_val_best = np.inf
    running_metrics = {}
    running_losses = {}
    epoch_it = -1
    total_it = -1

    # Load the pretrained weights
    if cfg['network']['use_pretrained'] and cfg['network']['pretrained_path']:
        model, optimizer, scheduler, epoch_it, total_it, metric_val_best = load_checkpoint(
            model,
            optimizer,
            scheduler,
            filename=cfg['network']['pretrained_path'])

        # Find previous tensorboard files and copy them
        tb_files = glob.glob(
            os.sep.join(cfg['network']['pretrained_path'].split(os.sep)[:-1]) +
            '/events.*')
        for tb_file in tb_files:
            shutil.copy(tb_file,
                        os.path.join(path2log,
                                     tb_file.split(os.sep)[-1]))

    # Initialize the trainer
    device = torch.device('cuda' if (
        torch.cuda.is_available() and cfg['misc']['use_gpu']) else 'cpu')
    trainer = config.get_trainer(cfg, model, device)
    acc_iter_size = cfg['train']['acc_iter_size']

    # Training loop
    while epoch_it < cfg['train']['max_epoch']:
        epoch_it += 1
        lr = scheduler.get_last_lr()
        logger.info('Training epoch: {}, LR: {} '.format(epoch_it, lr))
        gc.collect()

        train_loader_iter = iter(train_loader)
        start = time.time()
        tbar = tqdm(total=len(train_loader) // acc_iter_size, ncols=100)

        for it in range(len(train_loader) // acc_iter_size):
            optimizer.zero_grad()
            total_it += 1
            batch_metrics = {}
            batch_losses = {}

            for iter_idx in range(acc_iter_size):

                # use the next() builtin; DataLoader iterators no longer expose .next() on Python 3
                batch = next(train_loader_iter)

                dict_all_to_device(batch, device)
                losses, metrics, total_loss = trainer.train_step(batch)

                total_loss.backward()

                # Save the running metrics and losses
                if not batch_metrics:
                    batch_metrics = copy.deepcopy(metrics)
                else:
                    for key, value in metrics.items():
                        batch_metrics[key] += value

                if not batch_losses:
                    batch_losses = copy.deepcopy(losses)
                else:
                    for key, value in losses.items():
                        batch_losses[key] += value

            # Compute the mean value of the metrics and losses of the batch
            for key, value in batch_metrics.items():
                batch_metrics[key] = value / acc_iter_size

            for key, value in batch_losses.items():
                batch_losses[key] = value / acc_iter_size

            optimizer.step()
            torch.cuda.empty_cache()

            tbar.set_description('Loss: {:.3g}'.format(
                batch_losses['total_loss']))
            tbar.update(1)

            # Save the running metrics and losses
            if not running_metrics:
                running_metrics = copy.deepcopy(batch_metrics)
            else:
                for key, value in batch_metrics.items():
                    running_metrics[key] += value

            if not running_losses:
                running_losses = copy.deepcopy(batch_losses)
            else:
                for key, value in batch_losses.items():
                    running_losses[key] += value

            # Logs
            if total_it % stat_interval == stat_interval - 1:
                # Print / save logs
                logger.info("Epoch {0:d} - It. {1:d}: loss = {2:.3f}".format(
                    epoch_it, total_it,
                    running_losses['total_loss'] / stat_interval))

                for key, value in running_losses.items():
                    tboard_logger.add_scalar("Train/{}".format(key),
                                             value / stat_interval, total_it)
                    # Reinitialize the values
                    running_losses[key] = 0

                for key, value in running_metrics.items():
                    tboard_logger.add_scalar("Train/{}".format(key),
                                             value / stat_interval, total_it)
                    # Reinitialize the values
                    running_metrics[key] = 0

                start = time.time()

            # Run validation
            if total_it % val_interval == val_interval - 1:
                logger.info("Starting the validation")
                val_losses, val_metrics = trainer.validate(val_loader)

                for key, value in val_losses.items():
                    tboard_logger.add_scalar("Val/{}".format(key), value,
                                             total_it)

                for key, value in val_metrics.items():
                    tboard_logger.add_scalar("Val/{}".format(key), value,
                                             total_it)

                logger.info(
                    "VALIDATION -It. {0:d}: total loss: {1:.3f}.".format(
                        total_it, val_losses['total_loss']))

                if val_losses['total_loss'] < metric_val_best:
                    metric_val_best = val_losses['total_loss']
                    logger.info('New best model (loss: {:.4f})'.format(
                        metric_val_best))

                    save_checkpoint(os.path.join(path2log, 'model_best.pt'),
                                    epoch=epoch_it,
                                    it=total_it,
                                    model=model,
                                    optimizer=optimizer,
                                    scheduler=scheduler,
                                    config=cfg,
                                    best_val=metric_val_best)
                else:
                    save_checkpoint(os.path.join(
                        path2log, 'model_{}.pt'.format(total_it)),
                                    epoch=epoch_it,
                                    it=total_it,
                                    model=model,
                                    optimizer=optimizer,
                                    scheduler=scheduler,
                                    config=cfg,
                                    best_val=val_losses['total_loss'])

        # After the epoch is finished, update the scheduler
        scheduler.step()

    # Quit after the maximum number of epochs is reached
    logger.info(
        'Training completed after {} Epochs ({} it) with best val metric ({})={}'
        .format(epoch_it, it, model_selection_metric, metric_val_best))
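
The keyword-style save_checkpoint(path, epoch=..., it=..., model=..., optimizer=..., scheduler=..., config=..., best_val=...) used above is likewise not shown here; it is presumably the counterpart of the load_checkpoint call earlier in the example. A sketch assuming it packs the full training state into a single file:

import torch

def save_checkpoint(path, epoch=None, it=None, model=None, optimizer=None,
                    scheduler=None, config=None, best_val=None):
    # Bundle everything a matching load_checkpoint would need to resume training later.
    torch.save(
        {
            "epoch": epoch,
            "it": it,
            "state_dict": model.state_dict() if model is not None else None,
            "optimizer": optimizer.state_dict() if optimizer is not None else None,
            "scheduler": scheduler.state_dict() if scheduler is not None else None,
            "config": config,
            "best_val": best_val,
        }, path)
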
Example #29
0
def main():
    global best_test_bpd

    last_checkpoints = []
    lipschitz_constants = []
    ords = []

    # if args.resume:
    #     validate(args.begin_epoch - 1, model, ema)
    for epoch in range(args.begin_epoch, args.nepochs):
        print("epoch %d of %d, %s" % (epoch, args.nepochs, datetime.now()))
        stdout.flush()

        logger.info('Current LR {}'.format(optimizer.param_groups[0]['lr']))

        train(epoch, model)
        lipschitz_constants.append(get_lipschitz_constants(model))
        ords.append(get_ords(model))
        logger.info('Lipsh: {}'.format(pretty_repr(lipschitz_constants[-1])))
        logger.info('Order: {}'.format(pretty_repr(ords[-1])))

        if args.ema_val:
            test_bpd = validate(epoch, model, ema)
        else:
            test_bpd = validate(epoch, model)

        if args.scheduler and scheduler is not None:
            scheduler.step()

        if test_bpd < best_test_bpd:
            best_test_bpd = test_bpd
            utils.save_checkpoint(
                {
                    'state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'args': args,
                    'ema': ema,
                    'test_bpd': test_bpd,
                },
                os.path.join(args.save, 'models'),
                epoch,
                last_checkpoints,
                num_checkpoints=5)

            torch.save(
                {
                    "density_model": model,
                    "args": args,
                    "input_size": input_size,
                    "n_classes": n_classes,
                    "im_dim": im_dim,
                    "epoch": epoch
                },
                "/scratch/shared/nfs1/xuji/generalization/models/%s_resflow_full_model.pytorch"
                % args.data)
            print("saved best")

        torch.save(
            {
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'args': args,
                'ema': ema,
                'test_bpd': test_bpd,
            }, os.path.join(args.save, 'models', 'most_recent.pth'))