def train(args):

    if args.ckpt_path and not args.use_pretrained:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(**vars(args))
        if args.use_pretrained:
            model.load_pretrained(args.ckpt_path, args.gpu_ids)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
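    # Note: fine_tuning_parameters presumably returns parameter groups split at
    # fine_tuning_boundary, with fine_tuning_lr applied to the earlier layers.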
    if args.use_pretrained or args.fine_tune:
        parameters = model.module.fine_tuning_parameters(
            args.fine_tuning_boundary, args.fine_tuning_lr)
    else:
        parameters = model.parameters()
    optimizer = util.get_optimizer(parameters, args)
    lr_scheduler = util.get_scheduler(optimizer, args)
    if args.ckpt_path and not args.use_pretrained and not args.fine_tune:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    cls_loss_fn = util.get_loss_fn(is_classification=True,
                                   dataset=args.dataset,
                                   size_average=False)
    data_loader_fn = data_loader.__dict__[args.data_loader]
    train_loader = data_loader_fn(args, phase='train', is_training=True)
    logger = TrainLogger(args, len(train_loader.dataset),
                         train_loader.dataset.pixel_dict)
    eval_loaders = [data_loader_fn(args, phase='val', is_training=False)]
    evaluator = ModelEvaluator(args.do_classify, args.dataset, eval_loaders,
                               logger, args.agg_method, args.num_visuals,
                               args.max_eval, args.epochs_per_eval)
    saver = ModelSaver(args.save_dir, args.epochs_per_save, args.max_ckpts,
                       args.best_ckpt_metric, args.maximize_metric)

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, target_dict in train_loader:
            logger.start_iter()

            with torch.set_grad_enabled(True):
                inputs = inputs.to(args.device)
                cls_logits = model.forward(inputs)
                cls_targets = target_dict['is_abnormal']
                cls_loss = cls_loss_fn(cls_logits, cls_targets.to(args.device))
                loss = cls_loss.mean()

                logger.log_iter(inputs, cls_logits, target_dict, loss, optimizer)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            logger.end_iter()
            util.step_scheduler(lr_scheduler, global_step=logger.global_step)

        metrics, curves = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch,
                   model,
                   optimizer,
                   lr_scheduler,
                   args.device,
                   metric_val=metrics.get(args.best_ckpt_metric, None))
        logger.end_epoch(metrics, curves)
        util.step_scheduler(lr_scheduler,
                            metrics,
                            epoch=logger.epoch,
                            best_ckpt_metric=args.best_ckpt_metric)
Example #2
def train(args):
    # Get loader for outer loop training
    loader = get_loader(args)
    target_image_shape = loader.dataset.target_image_shape
    setattr(args, 'target_image_shape', target_image_shape)

    # Load model
    model_fn = models.__dict__[args.model]
    model = model_fn(**vars(args))
    model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Print model parameters
    print('Model parameters: name, size, mean, std')
    for name, param in model.named_parameters():
        print(name, param.size(), torch.mean(param), torch.std(param))

    # Get optimizer and loss
    parameters = model.parameters()
    optimizer = util.get_optimizer(parameters, args)
    loss_fn = util.get_loss_fn(args.loss_fn, args)

    z_loss_fn = util.get_loss_fn(args.loss_fn, args)

    # Get logger, saver
    logger = TrainLogger(args)
    saver = ModelSaver(args)

    print(f'Logs: {logger.log_dir}')
    print(f'Ckpts: {args.save_dir}')

    # Train model
    logger.log_hparams(args)
    batch_size = args.batch_size
    while not logger.is_finished_training():
        logger.start_epoch()

        for input_noise, target_image, mask, z_test_target, z_test in loader:
            logger.start_iter()

            input_noise = input_noise.to(args.device)
            target_image = target_image.to(args.device)
            mask = mask.to(args.device)
            z_test = z_test.to(args.device)
            z_test_target = z_test_target.to(args.device)

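            # The mask splits the target into a visible region (trained on below)
            # and an obscured region (held out and only scored during eval).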
            masked_target_image = target_image * mask
            obscured_target_image = target_image * (1.0 - mask)

            # Input is noise tensor, target is image
            model.train()
            with torch.set_grad_enabled(True):
                if args.use_intermediate_logits:
                    logits = model.forward(input_noise).float()
                    probs = torch.sigmoid(logits)

                    # Debug logits and diffs
                    logger.debug_visualize(
                        [logits, logits * mask, logits * (1.0 - mask)],
                        unique_suffix='logits-train')
                else:
                    probs = model.forward(input_noise).float()

                # With backprop, calculate (1) the masked loss: the loss where the mask is applied.
                # The loss is computed elementwise without reduction, so take the mean afterwards;
                # this makes debugging easier.
                masked_probs = probs * mask
                masked_loss = loss_fn(masked_probs, masked_target_image).mean()

                masked_loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            # Without backprop, calculate (2) the full loss on the entire image,
            # and (3) the obscured loss on the region obscured by the mask.
            model.eval()
            with torch.no_grad():
                if args.use_intermediate_logits:
                    logits_eval = model.forward(input_noise).float()
                    probs_eval = torch.sigmoid(logits_eval)

                    # Debug logits and diffs
                    logger.debug_visualize(
                        [logits_eval, logits_eval * mask, logits_eval * (1.0 - mask)],
                        unique_suffix='logits-eval')
                else:
                    probs_eval = model.forward(input_noise).float()

                masked_probs_eval = probs_eval * mask
                masked_loss_eval = loss_fn(masked_probs_eval,
                                           masked_target_image).mean()

                full_loss_eval = loss_fn(probs_eval, target_image).mean()

                obscured_probs_eval = probs_eval * (1.0 - mask)
                obscured_loss_eval = loss_fn(obscured_probs_eval,
                                             obscured_target_image).mean()

            # With backprop on only the input z, (4) run one step of z-test and get z-loss
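            # This optimizer holds only z_test, so the step below updates the
            # latent input while the model weights stay fixed.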
            z_optimizer = util.get_optimizer([z_test.requires_grad_()], args)
            with torch.set_grad_enabled(True):
                if args.use_intermediate_logits:
                    z_logits = model.forward(z_test).float()
                    z_probs = torch.sigmoid(z_logits)
                else:
                    z_probs = model.forward(z_test).float()

                z_loss = z_loss_fn(z_probs, z_test_target).mean()

                z_loss.backward()
                z_optimizer.step()
                z_optimizer.zero_grad()

            if z_loss < args.max_z_test_loss:  # TODO: include this part into the metrics/saver stuff below
                # Save MSE on obscured region
                final_metrics = {'final/score': obscured_loss_eval.item()}
                logger._log_scalars(final_metrics)
                print('z loss', z_loss.item())
                print('Final MSE value', obscured_loss_eval.item())

            # TODO: Make a function for metrics - or at least make sure dict includes all possible best ckpt metrics
            metrics = {'masked_loss': masked_loss.item()}
            saver.save(logger.global_step,
                       model,
                       optimizer,
                       args.device,
                       metric_val=metrics.get(args.best_ckpt_metric, None))
            # Log both train and eval model settings, and visualize their outputs
            logger.log_status(
                inputs=input_noise,
                targets=target_image,
                probs=probs,
                masked_probs=masked_probs,
                masked_loss=masked_loss,
                probs_eval=probs_eval,
                masked_probs_eval=masked_probs_eval,
                obscured_probs_eval=obscured_probs_eval,
                masked_loss_eval=masked_loss_eval,
                obscured_loss_eval=obscured_loss_eval,
                full_loss_eval=full_loss_eval,
                z_target=z_test_target,
                z_probs=z_probs,
                z_loss=z_loss,
                save_preds=args.save_preds,
            )

            logger.end_iter()

        logger.end_epoch()

    # Last log after everything completes
    logger.log_status(
        inputs=input_noise,
        targets=target_image,
        probs=probs,
        masked_probs=masked_probs,
        masked_loss=masked_loss,
        probs_eval=probs_eval,
        masked_probs_eval=masked_probs_eval,
        obscured_probs_eval=obscured_probs_eval,
        masked_loss_eval=masked_loss_eval,
        obscured_loss_eval=obscured_loss_eval,
        full_loss_eval=full_loss_eval,
        z_target=z_test_target,
        z_probs=z_probs,
        z_loss=z_loss,
        save_preds=args.save_preds,
        force_visualize=True,
    )
Example #3
def train(args):
    # Get model
    model = models.__dict__[args.model](args)
    if args.ckpt_path:
        model = ModelSaver.load_model(model,
                                      args.ckpt_path,
                                      args.gpu_ids,
                                      is_training=True)
    model = model.to(args.device)
    model.train()

    # Get loader, logger, and saver
    train_loader, val_loader = get_data_loaders(args)
    logger = TrainLogger(args, model, dataset_len=len(train_loader.dataset))
    saver = ModelSaver(args.save_dir,
                       args.max_ckpts,
                       metric_name=args.metric_name,
                       maximize_metric=args.maximize_metric,
                       keep_topk=True)

    # Train
    while not logger.is_finished_training():
        logger.start_epoch()
        for batch in train_loader:
            logger.start_iter()

            # Train over one batch
            model.set_inputs(batch['src'], batch['tgt'])
            model.train_iter()

            logger.end_iter()

            # Evaluate
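            # Assuming global_step advances by batch_size per iteration, this
            # condition fires roughly once every iters_per_eval examples.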
            if logger.global_step % args.iters_per_eval < args.batch_size:
                criteria = {'MSE_src2tgt': mse, 'MSE_tgt2src': mse}
                stats = evaluate(model, val_loader, criteria)
                logger.log_scalars({'val_' + k: v for k, v in stats.items()})
                saver.save(logger.global_step, model, stats[args.metric_name],
                           args.device)

        logger.end_epoch()
Example #4
def train(args):
    """Train model.

    Args:
        args: Command line arguments.
    """
    # Set up model
    model = models.__dict__[args.model](**vars(args))
    model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)

    # Set up data loader
    train_loader, test_loader, classes = get_cifar_loaders(
        args.batch_size, args.num_workers)

    # Set up optimizer
    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.sgd_momentum,
                          weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step,
                                          args.lr_decay_gamma)
    loss_fn = nn.CrossEntropyLoss().to(args.device)

    # Set up checkpoint saver
    saver = ModelSaver(model,
                       optimizer,
                       scheduler,
                       args.save_dir, {'model': args.model},
                       max_to_keep=args.max_ckpts,
                       device=args.device)

    # Train
    logger = TrainLogger(args, len(train_loader.dataset))

    while not logger.is_finished_training():
        logger.start_epoch()

        # Train for one epoch
        model.train()
        for inputs, labels in train_loader:
            logger.start_iter()

            with torch.set_grad_enabled(True):
                # Forward
                outputs = model.forward(inputs.to(args.device))
                loss = loss_fn(outputs, labels.to(args.device))
                loss_item = loss.item()

                # Backward
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            logger.end_iter({'loss': loss_item})

        # Evaluate on validation set
        val_loss = evaluate(model, test_loader, loss_fn, device=args.device)
        logger.write('[epoch {}]: val_loss: {:.3g}'.format(
            logger.epoch, val_loss))
        logger.write_summaries({'loss': val_loss}, phase='val')
        if logger.epoch in args.save_epochs:
            saver.save(logger.epoch, val_loss)

        logger.end_epoch()
        scheduler.step()
Example #5
def main():
    parser = ArgParser()
    args = parser.parse_args()

    gen = Generator(args.latent_dim).to(args.device)
    disc = Discriminator().to(args.device)
    if args.device != 'cpu':
        gen = nn.DataParallel(gen, args.gpu_ids)
        disc = nn.DataParallel(disc, args.gpu_ids)
    # gen = gen.apply(weights_init)
    # disc = disc.apply(weights_init)

    gen_opt = torch.optim.RMSprop(gen.parameters(), lr=args.lr)
    disc_opt = torch.optim.RMSprop(disc.parameters(), lr=args.lr)
    gen_scheduler = torch.optim.lr_scheduler.LambdaLR(gen_opt, lr_lambda=lr_lambda(args.num_epochs))
    disc_scheduler = torch.optim.lr_scheduler.LambdaLR(disc_opt, lr_lambda=lr_lambda(args.num_epochs))
    disc_loss_fn = DiscriminatorLoss().to(args.device)
    gen_loss_fn = GeneratorLoss().to(args.device)

    # dataset = Dataset()
    dataset = MNISTDataset()
    loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)

    logger = TrainLogger(args, len(loader), phase=None)
    logger.log_hparams(args)

    if args.privacy_noise_multiplier != 0:
        privacy_engine = PrivacyEngine(
            disc,
            batch_size=args.batch_size,
            sample_size=len(dataset),
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=args.privacy_noise_multiplier,
            max_grad_norm=0.02,
            batch_first=True,
        )
        privacy_engine.attach(disc_opt)
        privacy_engine.to(args.device)

    for epoch in range(args.num_epochs):
        logger.start_epoch()
        for cur_step, img in enumerate(tqdm(loader, dynamic_ncols=True)):
            logger.start_iter()
            img = img.to(args.device)
            fake, disc_loss = None, None
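            # The discriminator takes several update steps per generator update.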
            for _ in range(args.step_train_discriminator):
                disc_opt.zero_grad()
                fake_noise = get_noise(args.batch_size, args.latent_dim, device=args.device)
                fake = gen(fake_noise)
                disc_loss = disc_loss_fn(img, fake, disc)
                disc_loss.backward()
                disc_opt.step()

            gen_opt.zero_grad()
            fake_noise_2 = get_noise(args.batch_size, args.latent_dim, device=args.device)
            fake_2 = gen(fake_noise_2)
            gen_loss = gen_loss_fn(img, fake_2, disc)
            gen_loss.backward()
            gen_opt.step()
            if args.privacy_noise_multiplier != 0:
                epsilon, best_alpha = privacy_engine.get_privacy_spent(args.privacy_delta)

            logger.log_iter_gan_from_latent_vector(img, fake, gen_loss, disc_loss, epsilon if args.privacy_noise_multiplier != 0 else 0)
            logger.end_iter()

        logger.end_epoch()
        gen_scheduler.step()
        disc_scheduler.step()
Example #6
def main():
    args = easydict.EasyDict({
        # "dataroot": "/mnt/gold/users/s18150/mywork/pytorch/data/gan",
        "dataroot": "/mnt/gold/users/s18150/mywork/pytorch/data",
        "save_dir": "./",
        "prefix": "test",
        "workers": 8,
        "batch_size": 128,
        "image_size": 32,
        # "image_size": 28,
        # "nc": 3,
        "nc": 1,
        "nz": 100,
        "ngf": 32,
        "ndf": 32,
        # "ngf": 28,
        # "ndf": 64,
        "epochs": 1,
        "lr": 0.0002,
        "beta1": 0.5,
        "gpu": 7,
        "use_cuda": True,
        "feature_matching": True,
        "mini_batch": True,
        "iters": 50000,
        "label_batch_size": 100,
        "unlabel_batch_size": 100,
        "test_batch_size": 10,
        "out_dir": './result',
        "log_interval": 500,
        "label_num": 20
    })

    manualSeed = 999
    np.random.seed(manualSeed)
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    device = torch.device(
        'cuda:{}'.format(args.gpu) if args.use_cuda else 'cpu')

    # transform = transforms.Compose([
    #     transforms.Resize(args.image_size),
    #     transforms.CenterCrop(args.image_size),
    #     transforms.ToTensor(),
    #     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # ])
    #
    # dataset = dset.ImageFolder(root=args.dataroot, transform=transform)
    # dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
    #                                          shuffle=True, num_workers=args.workers)

    data_iterators = dataset.get_iters(root_path=args.dataroot,
                                       l_batch_size=args.label_batch_size,
                                       ul_batch_size=args.unlabel_batch_size,
                                       test_batch_size=args.test_batch_size,
                                       workers=args.workers,
                                       n_labeled=args.label_num)

    trainloader_label = data_iterators['labeled']
    trainloader_unlabel = data_iterators['unlabeled']
    testloader = data_iterators['test']

    # Create the generator model instance
    netG = net.Generator(args.nz, args.ngf, args.nc).to(device)
    # Initialize the generator weights
    netG.apply(net.weights_init)

    # Create the discriminator model instance
    netD = net.Discriminator(args.nc, args.ndf, device, args.batch_size,
                             args.mini_batch).to(device)
    # Initialize the discriminator weights
    netD.apply(net.weights_init)

    # Create the discriminator loss criterion
    criterionD = nn.CrossEntropyLoss()
    # criterionD = nn.BCELoss()

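    # With feature matching, the generator is trained with an MSE criterion
    # (presumably applied to discriminator feature statistics); otherwise a standard BCE GAN loss.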
    if args.feature_matching is True:
        criterionG = nn.MSELoss(reduction='mean')
    else:
        criterionG = nn.BCELoss()

    # Create a fixed batch of noise (batch size 64) to feed the generator.
    # It is used to visualize the generator's outputs during training.
    fixed_noise = torch.randn(64, args.nz, 1, 1, device=device)

    # Create the optimizer instances
    optimizerD = optim.Adam(netD.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, 0.999))

    logger = TrainLogger(args)
    r = run.NNRun(netD, netG, optimizerD, optimizerG, criterionD, criterionG,
                  device, fixed_noise, logger, args)

    # Train
    # r.train(dataloader)
    r.train(trainloader_label, trainloader_unlabel, testloader)
Example #7
def train(args):

    if args.ckpt_path:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(**vars(args))
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = optim.get_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), args)
    lr_scheduler = optim.get_scheduler(optimizer, args)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    loss_fn = nn.CrossEntropyLoss()
    train_loader = CIFARLoader('train', args.batch_size, args.num_workers)
    logger = TrainLogger(args, len(train_loader.dataset))
    eval_loaders = [CIFARLoader('val', args.batch_size, args.num_workers)]
    evaluator = ModelEvaluator(eval_loaders, logger, args.max_eval,
                               args.epochs_per_eval)
    saver = ModelSaver(**vars(args))

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, targets in train_loader:
            logger.start_iter()

            with torch.set_grad_enabled(True):
                logits = model.forward(inputs.to(args.device))
                loss = loss_fn(logits, targets.to(args.device))

                logger.log_iter(loss)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            logger.end_iter()

        metrics = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch,
                   model,
                   optimizer,
                   lr_scheduler,
                   args.device,
                   metric_val=metrics.get(args.metric_name, None))
        logger.end_epoch(metrics)
        optim.step_scheduler(lr_scheduler, metrics, logger.epoch)
Example #8
def main(args):
    write_args(args)

    # Set up main device and scale batch size
    device = 'cuda' if torch.cuda.is_available() and args.gpu_ids else 'cpu'
    print(device)

    # Set random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Build Model
    print("Building model...")
    model_fn = models.__dict__[args.model]
    model = model_fn(args, device)
    model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(device)

    # Loss fn
    loss_fn = get_loss(args.model)

    # Data loaders
    train_loader = get_dataloader(args, "train")
    val_loader = get_dataloader(args, "val")

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Logger and Resume
    if args.resume:
        resume_path = os.path.join(args.save_dir, "current.pth.tar")
        print("Resuming from checkpoint at {}".format(resume_path))
        checkpoint = torch.load(resume_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        start_iter = checkpoint['iter']
        global_step = start_epoch * len(train_loader)
        logger = TrainLogger(args, start_epoch, global_step)
        logger.best_val_loss = checkpoint['val_loss']
        print('Resuming at epoch {}, iter {}'.format(start_epoch, start_iter))
    else:
        start_epoch = 0
        global_step = 0
        start_iter = 0
        logger = TrainLogger(args, start_epoch, global_step)

    # Sampler
    sampler = get_sampler(args.model, 0, 16, args.size, args.input_c,
                          args.save_dir, device)

    for i in range(start_epoch, args.num_epochs):

        # Train
        model.train()
        logger.start_epoch()
        for j, image in enumerate(train_loader):
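            # When resuming mid-epoch, skip batches that were already processed
            # before the checkpoint was written.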
            if j < start_iter:
                logger.end_iter()
                continue

            # Sample and Eval
            if j % 250 == 0:
                print("Sampling...")
                sampler.sample(model, i, j)

                with torch.no_grad():
                    logger.val_loss_meter.reset()
                    model.eval()
                    # Use a separate name so the training batch `image` is not overwritten
                    for val_image in tqdm(val_loader):
                        val_image = val_image.to(device)
                        output = model(val_image)
                        loss = loss_fn(output, val_image)
                        logger.val_loss_meter.update(loss)
                logger.has_improved(model, optimizer, j)
                logger._log_scalars({'val-loss': logger.val_loss_meter.avg})
                model.train()

            logger.start_iter()
            image = image.to(device)
            optimizer.zero_grad()
            output = model(image)
            loss = loss_fn(output, image)
            loss.backward()
            for group in optimizer.param_groups:
                utils.clip_grad_norm_(group['params'], args.max_grad_norm, 2)
            optimizer.step()

            logger.log_iter(loss)
            logger.end_iter()
        logger.end_epoch(None, optimizer)
        start_iter = 0
Example #9
def train(args):
    train_loader = get_loader(args=args)
    if args.ckpt_path:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        args.D_in = train_loader.D_in
        model = model_fn(**vars(args))
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = optim.get_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), args)
    lr_scheduler = optim.get_scheduler(optimizer, args)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    loss_fn = optim.get_loss_fn(args.loss_fn, args)

    logger = TrainLogger(args, len(train_loader.dataset))
    eval_loaders = [
        get_loader(args, phase='train', is_training=False),
        get_loader(args, phase='valid', is_training=False)
    ]
    evaluator = ModelEvaluator(args, eval_loaders, logger, args.max_eval,
                               args.epochs_per_eval)

    saver = ModelSaver(**vars(args))

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for src, tgt in train_loader:
            logger.start_iter()
            with torch.set_grad_enabled(True):
                pred_params = model.forward(src.to(args.device))
                ages = src[:, 1]
                loss = loss_fn(pred_params, tgt.to(args.device),
                               ages.to(args.device), args.use_intvl)
                #loss = loss_fn(pred_params, tgt.to(args.device), src.to(args.device), args.use_intvl)
                logger.log_iter(src, pred_params, tgt, loss)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            logger.end_iter()

        metrics = evaluator.evaluate(model, args.device, logger.epoch)
        # print(metrics)
        saver.save(logger.epoch, model, optimizer, lr_scheduler, args.device,
                   metric_val=metrics.get(args.metric_name, None))
        logger.end_epoch(metrics=metrics)
Example #10
def train(args):
    """Run training loop with the given args.

    The function consists of the following steps:
        1. Load model: gets the model from a checkpoint or from models/models.py.
        2. Load optimizer and learning rate scheduler.
        3. Get data loaders and class weights.
        4. Get loss functions: cross entropy loss and weighted loss functions.
        5. Get logger, evaluator, and saver.
        6. Run training loop, evaluate and save model periodically.
    """

    model_args = args.model_args
    logger_args = args.logger_args
    optim_args = args.optim_args
    data_args = args.data_args
    transform_args = args.transform_args

    task_sequence = TASK_SEQUENCES[data_args.task_sequence]

    # Get model
    if model_args.ckpt_path:
        model_args.pretrained = False
        model, ckpt_info = ModelSaver.load_model(model_args.ckpt_path,
                                                 args.gpu_ids, model_args,
                                                 data_args)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[model_args.model]
        model = model_fn(task_sequence, model_args)
        if model_args.hierarchy:
            model = models.HierarchyWrapper(model, task_sequence)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = util.get_optimizer(model.parameters(), optim_args)
    lr_scheduler = util.get_scheduler(optimizer, optim_args)
    if model_args.ckpt_path:
        ModelSaver.load_optimizer(model_args.ckpt_path, args.gpu_ids,
                                  optimizer, lr_scheduler)

    # Get loaders and class weights
    train_csv_name = 'train'
    if data_args.uncertain_map_path is not None:
        train_csv_name = data_args.uncertain_map_path
    #TODO: Remove this when we decide which transformation to use in the end
    #transforms_imgaug = ImgAugTransform()
    train_loader = get_loader(data_args,
                              transform_args,
                              train_csv_name,
                              task_sequence,
                              data_args.su_train_frac,
                              data_args.nih_train_frac,
                              data_args.pocus_train_frac,
                              data_args.tcga_train_frac,
                              0,
                              0,
                              args.batch_size,
                              frontal_lateral=model_args.frontal_lateral,
                              is_training=True,
                              shuffle=True,
                              transform=model_args.transform,
                              normalize=model_args.normalize)
    eval_loaders = get_eval_loaders(data_args,
                                    transform_args,
                                    task_sequence,
                                    args.batch_size,
                                    frontal_lateral=model_args.frontal_lateral,
                                    normalize=model_args.normalize)
    class_weights = train_loader.dataset.class_weights
    print("Class weights:")
    print(class_weights)

    # Get loss functions
    uw_loss_fn = get_loss_fn('cross_entropy',
                             args.device,
                             model_args.model_uncertainty,
                             args.has_tasks_missing,
                             class_weights=class_weights)

    w_loss_fn = get_loss_fn('weighted_loss',
                            args.device,
                            model_args.model_uncertainty,
                            args.has_tasks_missing,
                            mask_uncertain=False,
                            class_weights=class_weights)

    # Get logger, evaluator and saver
    logger = TrainLogger(logger_args, args.start_epoch,
                         args.num_epochs, args.batch_size,
                         len(train_loader.dataset), args.device)

    eval_args = {}
    eval_args['num_visuals'] = logger_args.num_visuals
    eval_args['iters_per_eval'] = logger_args.iters_per_eval
    eval_args['has_missing_tasks'] = args.has_tasks_missing
    eval_args['model_uncertainty'] = model_args.model_uncertainty
    eval_args['class_weights'] = class_weights
    eval_args['max_eval'] = logger_args.max_eval
    eval_args['device'] = args.device
    eval_args['optimizer'] = args.optimizer
    evaluator = get_evaluator('classification', eval_loaders, logger,
                              eval_args)

    print("Eval Loaders: %d" % len(eval_loaders))
    saver = ModelSaver(**vars(logger_args))

    metrics = None
    lr_step = 0
    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, targets, info_dict in train_loader:

            logger.start_iter()

            # Evaluate and save periodically
            metrics, curves = evaluator.evaluate(model, args.device,
                                                 logger.global_step)
            logger.plot_metrics(metrics)
            metric_val = metrics.get(logger_args.metric_name, None)

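            # On evaluation steps the chosen checkpoint metric must be present.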
            assert logger.global_step % logger_args.iters_per_eval != 0 or metric_val is not None
            saver.save(logger.global_step,
                       logger.epoch,
                       model,
                       optimizer,
                       lr_scheduler,
                       args.device,
                       metric_val=metric_val)
            lr_step = util.step_scheduler(
                lr_scheduler,
                metrics,
                lr_step,
                best_ckpt_metric=logger_args.metric_name)

            # Input: [batch_size, channels, width, height]

            with torch.set_grad_enabled(True):

                logits = model.forward(inputs.to(args.device))

                unweighted_loss = uw_loss_fn(logits, targets.to(args.device))

                weighted_loss = (w_loss_fn(logits, targets.to(args.device))
                                 if w_loss_fn else None)

                logger.log_iter(inputs, logits, targets, unweighted_loss,
                                weighted_loss, optimizer)

                optimizer.zero_grad()
                if args.loss_fn == 'weighted_loss':
                    weighted_loss.backward()
                else:
                    unweighted_loss.backward()
                optimizer.step()

            logger.end_iter()

        logger.end_epoch(metrics, optimizer)
Example #11
def train(args):
    write_args(args)

    model_args = args.model_args
    data_args = args.data_args
    logger_args = args.logger_args

    print(f"Training {logger_args.name}")

    power_constraint = PowerConstraint()
    possible_inputs = get_md_set(model_args.md_len)
    channel = get_channel(data_args.channel, model_args.modelfree, data_args)

    model = AutoEncoder(model_args, data_args, power_constraint, channel,
                        possible_inputs)
    enc_scheduler = get_scheduler(model_args.scheduler, model_args.decay,
                                  model_args.patience)
    dec_scheduler = get_scheduler(model_args.scheduler, model_args.decay,
                                  model_args.patience)

    enc_scheduler.set_model(model.trainable_encoder)
    dec_scheduler.set_model(model.trainable_decoder)
    dataset_size = data_args.batch_size * data_args.batches_per_epoch * data_args.num_epochs
    loader = InputDataloader(data_args.batch_size, data_args.block_length,
                             dataset_size)
    loader = loader.example_generator()
    logger = TrainLogger(logger_args.save_dir, logger_args.name,
                         data_args.num_epochs, logger_args.iters_per_print)

    saver = ModelSaver(logger_args.save_dir, logger)

    enc_scheduler.on_train_begin()
    dec_scheduler.on_train_begin()

    while True:  # Loop until StopIteration
        try:
            metrics = None
            logger.start_epoch()
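            # Each outer step runs one encoder update plus train_ratio decoder
            # updates, hence the division by (train_ratio + 1).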
            for step in range(data_args.batches_per_epoch //
                              (model_args.train_ratio + 1)):
                # encoder train
                logger.start_iter()
                msg = next(loader)
                metrics = model.train_encoder(msg)
                logger.log_iter(metrics)
                logger.end_iter()

                # decoder train
                for _ in range(model_args.train_ratio):
                    logger.start_iter()
                    msg = next(loader)
                    metrics = model.train_decoder(msg)
                    logger.log_iter(metrics)
                    logger.end_iter()
            logger.end_epoch(None)

            if model_args.modelfree:
                model.Pi.std *= model_args.sigma_decay

            enc_scheduler.on_epoch_end(logger.epoch, logs=metrics)
            dec_scheduler.on_epoch_end(logger.epoch, logs=metrics)

            if logger.has_improved():
                saver.save(model)

            if logger.notImprovedCounter >= 7:
                break
        except StopIteration:
            break
Example #12
def train_classifier(args, model):
    """Train a classifier and save its first-layer weights.

    Args:
        args: Command line arguments.
        model: Classifier model to train.
    """
    # Set up data loader
    train_loader, test_loader, classes = get_data_loaders(
        args.dataset, args.batch_size, args.num_workers)

    # Set up model
    model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)

    fd = None
    if args.use_fd:
        fd = models.filter_discriminator()
        fd = nn.DataParallel(fd, args.gpu_ids)
        fd = fd.to(args.device)

    # Set up optimizer
    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.sgd_momentum,
                          weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step,
                                          args.lr_decay_gamma)
    if args.model == 'fd':
        post_process = nn.Sigmoid()
        loss_fn = nn.MSELoss().to(args.device)
    else:
        post_process = nn.Sequential()  # Identity
        loss_fn = nn.CrossEntropyLoss().to(args.device)

    # Set up checkpoint saver
    saver = ModelSaver(model,
                       optimizer,
                       scheduler,
                       args.save_dir, {'model': args.model},
                       max_to_keep=args.max_ckpts,
                       device=args.device)

    # Train
    logger = TrainLogger(args, len(train_loader.dataset))
    if args.save_all:
        # Save initialized model weights with validation loss as random
        saver.save(0, math.log(args.num_classes))
    while not logger.is_finished_training():
        logger.start_epoch()

        # Train for one epoch
        model.train()
        fd_lambda = get_fd_lambda(args, logger.epoch)
        for inputs, labels in train_loader:
            logger.start_iter()

            with torch.set_grad_enabled(True):
                # Forward
                outputs = model.forward(inputs.to(args.device))
                outputs = post_process(outputs)
                loss = loss_fn(outputs, labels.to(args.device))
                loss_item = loss.item()

                fd_loss = torch.zeros([], dtype=torch.float32,
                                      device='cuda' if args.gpu_ids else 'cpu')
                tp_total = torch.zeros([], dtype=torch.float32,
                                       device='cuda' if args.gpu_ids else 'cpu')
                if fd is not None:
                    # Forward FD
                    filters = get_layer_weights(model, filter_dict[args.model])
                    for i in range(0, filters.size(0), args.fd_batch_size):
                        fd_batch = filters[i:i + args.fd_batch_size]
                        tp_scores = torch.sigmoid(fd.forward(fd_batch))
                        tp_total += tp_scores.sum()
                    fd_loss = 1. - tp_total / filters.size(0)

                fd_loss_item = fd_loss.item()
                loss += fd_lambda * fd_loss

                # Backward
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            logger.end_iter({
                'std_loss': loss_item,
                'fd_loss': fd_loss_item,
                'loss': loss_item + fd_loss_item
            })

        # Evaluate on validation set
        val_loss = evaluate(model,
                            post_process,
                            test_loader,
                            loss_fn,
                            device=args.device)
        logger.write('[epoch {}]: val_loss: {:.3g}'.format(
            logger.epoch, val_loss))
        logger.write_summaries({'loss': val_loss}, phase='val')
        if args.save_all or logger.epoch in args.save_epochs:
            saver.save(logger.epoch, val_loss)

        logger.end_epoch()
        scheduler.step()
Example #13
def train(args):

    if args.ckpt_path:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(pretrained=args.pretrained)
        if args.pretrained:
            model.fc = nn.Linear(model.fc.in_features, args.num_classes)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    parameters = optim.get_parameters(model.module, args)
    optimizer = optim.get_optimizer(parameters, args)
    lr_scheduler = optim.get_scheduler(optimizer, args)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    loss_fn = nn.CrossEntropyLoss()
    train_loader = WhiteboardLoader(args.data_dir,
                                    'train',
                                    args.batch_size,
                                    shuffle=True,
                                    do_augment=True,
                                    num_workers=args.num_workers)
    logger = TrainLogger(args, len(train_loader.dataset))
    eval_loaders = [
        WhiteboardLoader(args.data_dir,
                         'val',
                         args.batch_size,
                         shuffle=False,
                         do_augment=False,
                         num_workers=args.num_workers)
    ]
    evaluator = ModelEvaluator(eval_loaders, logger, args.epochs_per_eval,
                               args.max_eval, args.num_visuals)
    saver = ModelSaver(**vars(args))

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, targets, paths in train_loader:
            logger.start_iter()

            with torch.set_grad_enabled(True):
                logits = model.forward(inputs.to(args.device))
                loss = loss_fn(logits, targets.to(args.device))

                logger.log_iter(inputs, logits, targets, paths, loss)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            optim.step_scheduler(lr_scheduler, global_step=logger.global_step)
            logger.end_iter()

        metrics = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch,
                   model,
                   args.model,
                   optimizer,
                   lr_scheduler,
                   args.device,
                   metric_val=metrics.get(args.metric_name, None))
        logger.end_epoch(metrics)
        optim.step_scheduler(lr_scheduler, metrics, logger.epoch)
Example #14
def train(args):
    """Run training loop with the given args.

    The function consists of the following steps:
        1. Load model: gets the model from a checkpoint or from models/models.py.
        2. Load optimizer and learning rate scheduler.
        3. Get data loaders and class weights.
        4. Get loss functions: cross entropy loss and weighted loss functions.
        5. Get logger, evaluator, and saver.
        6. Run training loop, evaluate and save model periodically.
    """
    model_args = args.model_args
    logger_args = args.logger_args
    optim_args = args.optim_args
    data_args = args.data_args
    transform_args = args.transform_args

    task_sequence = TASK_SEQUENCES[data_args.task_sequence]
    print('gpus: ', args.gpu_ids)
    # Get model
    if model_args.ckpt_path:
        model_args.pretrained = False
        model, ckpt_info = ModelSaver.load_model(model_args.ckpt_path, args.gpu_ids, model_args, data_args)
        if not logger_args.restart_epoch_count:
            args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[model_args.model]
        model = model_fn(task_sequence, model_args)
        num_covars = len(model_args.covar_list.split(';'))
        model.transform_model_shape(len(task_sequence), num_covars)
        if model_args.hierarchy:
            model = models.HierarchyWrapper(model, task_sequence)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = util.get_optimizer(model.parameters(), optim_args)
    lr_scheduler = util.get_scheduler(optimizer, optim_args)

    # The optimizer is loaded from the ckpt if one exists and the new model
    # architecture is the same as the old one (classifier is not transformed).
    if model_args.ckpt_path and not model_args.transform_classifier:
        ModelSaver.load_optimizer(model_args.ckpt_path, args.gpu_ids, optimizer, lr_scheduler)

    # Get loaders and class weights
    train_csv_name = 'train'
    if data_args.uncertain_map_path is not None:
        train_csv_name = data_args.uncertain_map_path

    # Put all CXR training fractions into one dictionary and pass it to the loader
    cxr_frac = {'pocus': data_args.pocus_train_frac, 'hocus': data_args.hocus_train_frac,
                'pulm': data_args.pulm_train_frac}
    train_loader = get_loader(data_args,
                              transform_args,
                              train_csv_name,
                              task_sequence,
                              data_args.su_train_frac,
                              data_args.nih_train_frac,
                              cxr_frac,
                              data_args.tcga_train_frac,
                              args.batch_size,
                              frontal_lateral=model_args.frontal_lateral,
                              is_training=True,
                              shuffle=True,
                              covar_list=model_args.covar_list,
                              fold_num=data_args.fold_num)
    eval_loaders = get_eval_loaders(data_args,
                                    transform_args,
                                    task_sequence,
                                    args.batch_size,
                                    frontal_lateral=model_args.frontal_lateral,
                                    covar_list=model_args.covar_list,
                                    fold_num=data_args.fold_num)
    class_weights = train_loader.dataset.class_weights

    # Get loss functions
    uw_loss_fn = get_loss_fn(args.loss_fn, args.device, model_args.model_uncertainty,
        args.has_tasks_missing, class_weights=class_weights)
    w_loss_fn = get_loss_fn('weighted_loss', args.device, model_args.model_uncertainty,
        args.has_tasks_missing, class_weights=class_weights)

    # Get logger, evaluator and saver
    logger = TrainLogger(logger_args, args.start_epoch, args.num_epochs, args.batch_size,
        len(train_loader.dataset), args.device, normalization=transform_args.normalization)
    
    eval_args = {}
    eval_args['num_visuals'] = logger_args.num_visuals
    eval_args['iters_per_eval'] = logger_args.iters_per_eval
    eval_args['has_missing_tasks'] = args.has_tasks_missing
    eval_args['model_uncertainty'] = model_args.model_uncertainty
    eval_args['class_weights'] = class_weights
    eval_args['max_eval'] = logger_args.max_eval
    eval_args['device'] = args.device
    eval_args['optimizer'] = optimizer
    evaluator = get_evaluator('classification', eval_loaders, logger, eval_args)

    print("Eval Loaders: %d" % len(eval_loaders))
    saver = ModelSaver(**vars(logger_args))

    metrics = None
    lr_step = 0
    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, targets, info_dict, covars in train_loader:
            logger.start_iter()

            # Evaluate and save periodically
            metrics, curves = evaluator.evaluate(model, args.device, logger.global_step)
            logger.plot_metrics(metrics)
            metric_val = metrics.get(logger_args.metric_name, None)
            assert logger.global_step % logger_args.iters_per_eval != 0 or metric_val is not None
            saver.save(logger.global_step, logger.epoch, model, optimizer, lr_scheduler, args.device,
                       metric_val=metric_val, covar_list=model_args.covar_list)
            lr_step = util.step_scheduler(lr_scheduler, metrics, lr_step, best_ckpt_metric=logger_args.metric_name)

            # Input: [batch_size, channels, width, height]

            with torch.set_grad_enabled(True):
            # with torch.autograd.set_detect_anomaly(True):

                logits = model.forward([inputs.to(args.device), covars])

                # Scale up TB so that its loss counts for more when upweight_tb is True.
                if model_args.upweight_tb is True:
                    tb_targets = targets.narrow(1, 0, 1)
                    findings_targets = targets.narrow(1, 1, targets.shape[1] - 1)
                    tb_targets = tb_targets.repeat(1, targets.shape[1] - 1)
                    new_targets = torch.cat((tb_targets, findings_targets), 1)

                    tb_logits = logits.narrow(1, 0, 1)
                    findings_logits = logits.narrow(1, 1, logits.shape[1] - 1)
                    tb_logits = tb_logits.repeat(1, logits.shape[1] - 1)
                    new_logits = torch.cat((tb_logits, findings_logits), 1)
                else:
                    new_logits = logits
                    new_targets = targets

                    
                unweighted_loss = uw_loss_fn(new_logits, new_targets.to(args.device))

                weighted_loss = w_loss_fn(logits, targets.to(args.device)) if w_loss_fn else None

                logger.log_iter(inputs, logits, targets, unweighted_loss, weighted_loss, optimizer)

                optimizer.zero_grad()
                if args.loss_fn == 'weighted_loss':
                    weighted_loss.backward()
                else:
                    unweighted_loss.backward()
                optimizer.step()

            logger.end_iter()

        logger.end_epoch(metrics, optimizer)