Example #1
def train(args):

    if args.ckpt_path:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(**vars(args))
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = optim.get_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), args)
    lr_scheduler = optim.get_scheduler(optimizer, args)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    loss_fn = nn.CrossEntropyLoss()
    train_loader = CIFARLoader('train', args.batch_size, args.num_workers)
    logger = TrainLogger(args, len(train_loader.dataset))
    eval_loaders = [CIFARLoader('val', args.batch_size, args.num_workers)]
    evaluator = ModelEvaluator(eval_loaders, logger, args.max_eval,
                               args.epochs_per_eval)
    saver = ModelSaver(**vars(args))

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, targets in train_loader:
            logger.start_iter()

            with torch.set_grad_enabled(True):
                logits = model.forward(inputs.to(args.device))
                loss = loss_fn(logits, targets.to(args.device))

                logger.log_iter(loss)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            logger.end_iter()

        metrics = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch,
                   model,
                   optimizer,
                   lr_scheduler,
                   args.device,
                   metric_val=metrics.get(args.metric_name, None))
        logger.end_epoch(metrics)
        optim.step_scheduler(lr_scheduler, metrics, logger.epoch)
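
These examples lean on a small optim helper module that is not shown on this page. Below is a minimal sketch of what get_optimizer and get_scheduler could look like; the argument names (args.optimizer, args.lr, args.momentum, args.weight_decay, args.lr_decay_step, args.lr_decay_gamma) are illustrative assumptions, not the project's actual interface.

import torch.optim
from torch.optim import lr_scheduler

def get_optimizer(parameters, args):
    """Build an optimizer from command-line args (illustrative sketch)."""
    if args.optimizer == 'sgd':
        return torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        return torch.optim.Adam(parameters, lr=args.lr,
                                weight_decay=args.weight_decay)
    raise ValueError('Unsupported optimizer: {}'.format(args.optimizer))

def get_scheduler(optimizer, args):
    """Decay the learning rate by a fixed factor every few epochs (illustrative sketch)."""
    return lr_scheduler.StepLR(optimizer,
                               step_size=args.lr_decay_step,
                               gamma=args.lr_decay_gamma)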
Example #2
def train(args):
    train_loader = get_loader(args=args)
    if args.ckpt_path:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        args.D_in = train_loader.D_in
        model = model_fn(**vars(args))
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = optim.get_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), args)
    lr_scheduler = optim.get_scheduler(optimizer, args)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    loss_fn = optim.get_loss_fn(args.loss_fn, args)

    logger = TrainLogger(args, len(train_loader.dataset))
    eval_loaders = [
        get_loader(args, phase='train', is_training=False),
        get_loader(args, phase='valid', is_training=False)
    ]
    evaluator = ModelEvaluator(args, eval_loaders, logger, args.max_eval,
                               args.epochs_per_eval)

    saver = ModelSaver(**vars(args))

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for src, tgt in train_loader:
            logger.start_iter()
            with torch.set_grad_enabled(True):
                pred_params = model.forward(src.to(args.device))
                ages = src[:, 1]
                loss = loss_fn(pred_params, tgt.to(args.device),
                               ages.to(args.device), args.use_intvl)
                #loss = loss_fn(pred_params, tgt.to(args.device), src.to(args.device), args.use_intvl)
                logger.log_iter(src, pred_params, tgt, loss)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            logger.end_iter()

        metrics = evaluator.evaluate(model, args.device, logger.epoch)
        # print(metrics)
        saver.save(logger.epoch, model, optimizer, lr_scheduler, args.device,
                   metric_val=metrics.get(args.metric_name, None))
        logger.end_epoch(metrics=metrics)
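
Both of the examples above resume through ModelSaver.load_model / load_optimizer and checkpoint through saver.save. The class below sketches such a round trip in a self-contained way; the checkpoint keys ('epoch', 'model_state', 'optimizer_state') and the name ModelSaverSketch are assumptions for illustration, not the project's actual format.

import torch

class ModelSaverSketch(object):
    """Illustrative checkpoint helper; not the ModelSaver used in the examples."""

    @staticmethod
    def save(ckpt_path, epoch, model, optimizer):
        torch.save({'epoch': epoch,
                    'model_state': model.state_dict(),
                    'optimizer_state': optimizer.state_dict()},
                   ckpt_path)

    @staticmethod
    def load_model(ckpt_path, model):
        ckpt = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(ckpt['model_state'])
        return model, {'epoch': ckpt['epoch']}

    @staticmethod
    def load_optimizer(ckpt_path, optimizer):
        ckpt = torch.load(ckpt_path, map_location='cpu')
        optimizer.load_state_dict(ckpt['optimizer_state'])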
Example #3
def train(args):

    if args.ckpt_path:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(**vars(args))
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Set up population-based training client
    pbt_client = PBTClient(args.pbt_server_url, args.pbt_server_port, args.pbt_server_key, args.pbt_config_path)

    # Get optimizer and scheduler
    parameters = model.module.parameters()
    optimizer = optim.get_optimizer(parameters, args, pbt_client)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, args.gpu_ids, optimizer)

    # Get logger, evaluator, saver
    train_loader = DataLoader(args, 'train', is_training_set=True)
    eval_loaders = [DataLoader(args, 'valid', is_training_set=False)]
    evaluator = ModelEvaluator(eval_loaders, args.epochs_per_eval,
                               args.max_eval, args.num_visuals, use_ten_crop=args.use_ten_crop)
    saver = ModelSaver(**vars(args))

    for _ in range(args.num_epochs):
        optim.update_hyperparameters(model.module, optimizer, pbt_client.hyperparameters())

        for inputs, targets in train_loader:
            with torch.set_grad_enabled(True):
                logits = model.forward(inputs.to(args.device))
                loss = F.binary_cross_entropy_with_logits(logits, targets.to(args.device))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        metrics = evaluator.evaluate(model, args.device)
        metric_val = metrics.get(args.metric_name, None)
        ckpt_path = saver.save(model, args.model, optimizer, args.device, metric_val)

        pbt_client.save(ckpt_path, metric_val)
        if pbt_client.should_exploit():
            # Exploit
            pbt_client.exploit()

            # Load model and optimizer parameters from exploited network
            model, ckpt_info = ModelSaver.load_model(pbt_client.parameters_path(), args.gpu_ids)
            model = model.to(args.device)
            model.train()
            ModelSaver.load_optimizer(pbt_client.parameters_path(), args.gpu_ids, optimizer)

            # Explore
            pbt_client.explore()
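
Example #3 adds population-based training: after each epoch the worker reports its checkpoint and score, may copy weights from a better worker (exploit), and then perturbs its hyperparameters (explore). PBTClient itself is project-specific; the toy function below only illustrates the explore step, and its perturbation factors are arbitrary.

import random

def explore(hyperparameters, factors=(0.8, 1.2)):
    """Return a perturbed copy of a dict of scalar hyperparameters (toy explore step)."""
    return {name: value * random.choice(factors)
            for name, value in hyperparameters.items()}

# Example: explore({'lr': 1e-3, 'weight_decay': 1e-4})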
Example #4
def train(args):

    if args.ckpt_path and not args.use_pretrained:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(**vars(args))
        if args.use_pretrained:
            model.load_pretrained(args.ckpt_path, args.gpu_ids)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    if args.use_pretrained or args.fine_tune:
        parameters = model.module.fine_tuning_parameters(
            args.fine_tuning_boundary, args.fine_tuning_lr)
    else:
        parameters = model.parameters()
    optimizer = util.get_optimizer(parameters, args)
    lr_scheduler = util.get_scheduler(optimizer, args)
    if args.ckpt_path and not args.use_pretrained and not args.fine_tune:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    cls_loss_fn = util.get_loss_fn(is_classification=True,
                                   dataset=args.dataset,
                                   size_average=False)
    data_loader_fn = data_loader.__dict__[args.data_loader]
    train_loader = data_loader_fn(args, phase='train', is_training=True)
    logger = TrainLogger(args, len(train_loader.dataset),
                         train_loader.dataset.pixel_dict)
    eval_loaders = [data_loader_fn(args, phase='val', is_training=False)]
    evaluator = ModelEvaluator(args.do_classify, args.dataset, eval_loaders,
                               logger, args.agg_method, args.num_visuals,
                               args.max_eval, args.epochs_per_eval)
    saver = ModelSaver(args.save_dir, args.epochs_per_save, args.max_ckpts,
                       args.best_ckpt_metric, args.maximize_metric)

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, target_dict in train_loader:
            logger.start_iter()

            with torch.set_grad_enabled(True):
                inputs = inputs.to(args.device)
                cls_logits = model.forward(inputs)
                cls_targets = target_dict['is_abnormal']
                cls_loss = cls_loss_fn(cls_logits, cls_targets.to(args.device))
                loss = cls_loss.mean()

                logger.log_iter(inputs, cls_logits, target_dict, loss, optimizer)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            logger.end_iter()
            util.step_scheduler(lr_scheduler, global_step=logger.global_step)

        metrics, curves = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch,
                   model,
                   optimizer,
                   lr_scheduler,
                   args.device,
                   metric_val=metrics.get(args.best_ckpt_metric, None))
        logger.end_epoch(metrics, curves)
        util.step_scheduler(lr_scheduler,
                            metrics,
                            epoch=logger.epoch,
                            best_ckpt_metric=args.best_ckpt_metric)
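
Example #4 steps the scheduler in two places: per iteration with global_step inside the batch loop, and per epoch with validation metrics after evaluation. The real util.step_scheduler is not shown, so the helper below is only a plausible sketch of a function that supports both call sites.

from torch.optim import lr_scheduler as lr_sched

def step_scheduler(scheduler, metrics=None, epoch=None,
                   global_step=None, best_ckpt_metric=None):
    """Sketch: plateau schedules step on a per-epoch metric, others per iteration."""
    if isinstance(scheduler, lr_sched.ReduceLROnPlateau):
        # Plateau schedules need a validation metric, available once per epoch.
        if metrics is not None and best_ckpt_metric in metrics:
            scheduler.step(metrics[best_ckpt_metric])
    elif global_step is not None:
        # Iteration-based schedules advance on every call that carries a global_step.
        scheduler.step()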
Example #5
File: train.py  Project: yxliang/lca-code
def train(args):
    """Run training loop with the given args.

    The function consists of the following steps:
        1. Load model: gets the model from a checkpoint or from models/models.py.
        2. Load optimizer and learning rate scheduler.
        3. Get data loaders and class weights.
        4. Get loss functions: cross entropy loss and weighted loss functions.
        5. Get logger, evaluator, and saver.
        6. Run training loop, evaluate and save model periodically.
    """

    model_args = args.model_args
    logger_args = args.logger_args
    optim_args = args.optim_args
    data_args = args.data_args
    transform_args = args.transform_args

    task_sequence = TASK_SEQUENCES[data_args.task_sequence]

    # Get model
    if model_args.ckpt_path:
        model_args.pretrained = False
        model, ckpt_info = ModelSaver.load_model(model_args.ckpt_path,
                                                 args.gpu_ids, model_args,
                                                 data_args)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[model_args.model]
        model = model_fn(task_sequence, model_args)
        if model_args.hierarchy:
            model = models.HierarchyWrapper(model, task_sequence)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = util.get_optimizer(model.parameters(), optim_args)
    lr_scheduler = util.get_scheduler(optimizer, optim_args)
    if model_args.ckpt_path:
        ModelSaver.load_optimizer(model_args.ckpt_path, args.gpu_ids,
                                  optimizer, lr_scheduler)

    # Get loaders and class weights
    train_csv_name = 'train'
    if data_args.uncertain_map_path is not None:
        train_csv_name = data_args.uncertain_map_path
    #TODO: Remove this when we decide which transformation to use in the end
    #transforms_imgaug = ImgAugTransform()
    train_loader = get_loader(data_args,
                              transform_args,
                              train_csv_name,
                              task_sequence,
                              data_args.su_train_frac,
                              data_args.nih_train_frac,
                              data_args.pocus_train_frac,
                              data_args.tcga_train_frac,
                              0,
                              0,
                              args.batch_size,
                              frontal_lateral=model_args.frontal_lateral,
                              is_training=True,
                              shuffle=True,
                              transform=model_args.transform,
                              normalize=model_args.normalize)
    eval_loaders = get_eval_loaders(data_args,
                                    transform_args,
                                    task_sequence,
                                    args.batch_size,
                                    frontal_lateral=model_args.frontal_lateral,
                                    normalize=model_args.normalize)
    class_weights = train_loader.dataset.class_weights
    print(" class weights:")
    print(class_weights)

    # Get loss functions
    uw_loss_fn = get_loss_fn('cross_entropy',
                             args.device,
                             model_args.model_uncertainty,
                             args.has_tasks_missing,
                             class_weights=class_weights)

    w_loss_fn = get_loss_fn('weighted_loss',
                            args.device,
                            model_args.model_uncertainty,
                            args.has_tasks_missing,
                            mask_uncertain=False,
                            class_weights=class_weights)

    # Get logger, evaluator and saver
    logger = TrainLogger(logger_args, args.start_epoch,
                         args.num_epochs, args.batch_size,
                         len(train_loader.dataset), args.device)

    eval_args = {}
    eval_args['num_visuals'] = logger_args.num_visuals
    eval_args['iters_per_eval'] = logger_args.iters_per_eval
    eval_args['has_missing_tasks'] = args.has_tasks_missing
    eval_args['model_uncertainty'] = model_args.model_uncertainty
    eval_args['class_weights'] = class_weights
    eval_args['max_eval'] = logger_args.max_eval
    eval_args['device'] = args.device
    eval_args['optimizer'] = args.optimizer
    evaluator = get_evaluator('classification', eval_loaders, logger,
                              eval_args)

    print("Eval Loaders: %d" % len(eval_loaders))
    saver = ModelSaver(**vars(logger_args))

    metrics = None
    lr_step = 0
    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, targets, info_dict in train_loader:

            logger.start_iter()

            # Evaluate and save periodically
            metrics, curves = evaluator.evaluate(model, args.device,
                                                 logger.global_step)
            logger.plot_metrics(metrics)
            metric_val = metrics.get(logger_args.metric_name, None)

            assert logger.global_step % logger_args.iters_per_eval != 0 or metric_val is not None
            saver.save(logger.global_step,
                       logger.epoch,
                       model,
                       optimizer,
                       lr_scheduler,
                       args.device,
                       metric_val=metric_val)
            lr_step = util.step_scheduler(
                lr_scheduler,
                metrics,
                lr_step,
                best_ckpt_metric=logger_args.metric_name)

            # Input: [batch_size, channels, width, height]

            with torch.set_grad_enabled(True):

                logits = model.forward(inputs.to(args.device))

                unweighted_loss = uw_loss_fn(logits, targets.to(args.device))

                weighted_loss = w_loss_fn(logits, targets.to(
                    args.device)) if w_loss_fn else None

                logger.log_iter(inputs, logits, targets, unweighted_loss,
                                weighted_loss, optimizer)

                optimizer.zero_grad()
                if args.loss_fn == 'weighted_loss':
                    weighted_loss.backward()
                else:
                    unweighted_loss.backward()
                optimizer.step()

            logger.end_iter()

        logger.end_epoch(metrics, optimizer)
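
Example #5 computes both an unweighted cross-entropy loss and a class-weighted loss, and supports datasets in which some tasks have no labels (has_tasks_missing). The function below sketches that idea for the binary multi-label case; the name masked_weighted_bce and the convention of encoding missing labels as -1 are assumptions, not the project's get_loss_fn.

import torch.nn.functional as F

def masked_weighted_bce(logits, targets, pos_weight=None, missing_value=-1):
    """Binary cross-entropy over [batch, num_tasks] that ignores missing labels (sketch)."""
    mask = (targets != missing_value).float()
    targets = targets.clamp(min=0).float()
    loss = F.binary_cross_entropy_with_logits(logits, targets,
                                              pos_weight=pos_weight,
                                              reduction='none')
    return (loss * mask).sum() / mask.sum().clamp(min=1)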
Example #6
def train(args):

    if args.ckpt_path:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(pretrained=args.pretrained)
        if args.pretrained:
            model.fc = nn.Linear(model.fc.in_features, args.num_classes)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    parameters = optim.get_parameters(model.module, args)
    optimizer = optim.get_optimizer(parameters, args)
    lr_scheduler = optim.get_scheduler(optimizer, args)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    loss_fn = nn.CrossEntropyLoss()
    train_loader = WhiteboardLoader(args.data_dir,
                                    'train',
                                    args.batch_size,
                                    shuffle=True,
                                    do_augment=True,
                                    num_workers=args.num_workers)
    logger = TrainLogger(args, len(train_loader.dataset))
    eval_loaders = [
        WhiteboardLoader(args.data_dir,
                         'val',
                         args.batch_size,
                         shuffle=False,
                         do_augment=False,
                         num_workers=args.num_workers)
    ]
    evaluator = ModelEvaluator(eval_loaders, logger, args.epochs_per_eval,
                               args.max_eval, args.num_visuals)
    saver = ModelSaver(**vars(args))

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, targets, paths in train_loader:
            logger.start_iter()

            with torch.set_grad_enabled(True):
                logits = model.forward(inputs.to(args.device))
                loss = loss_fn(logits, targets.to(args.device))

                logger.log_iter(inputs, logits, targets, paths, loss)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            optim.step_scheduler(lr_scheduler, global_step=logger.global_step)
            logger.end_iter()

        metrics = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch,
                   model,
                   args.model,
                   optimizer,
                   lr_scheduler,
                   args.device,
                   metric_val=metrics.get(args.metric_name, None))
        logger.end_epoch(metrics)
        optim.step_scheduler(lr_scheduler, metrics, logger.epoch)
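
When args.pretrained is set, Example #6 replaces the final fully-connected layer of a pretrained backbone with one sized for args.num_classes, the usual fine-tuning recipe for torchvision-style classifiers. A standalone version of that pattern, using resnet18 purely as an illustration (newer torchvision versions prefer the weights= argument over pretrained=):

import torch.nn as nn
import torchvision.models as tv_models

model = tv_models.resnet18(pretrained=True)      # load ImageNet weights
model.fc = nn.Linear(model.fc.in_features, 10)   # e.g. 10 target classes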
Example #7
def train(args):
    """Run training loop with the given args.

    The function consists of the following steps:
        1. Load model: gets the model from a checkpoint or from models/models.py.
        2. Load optimizer and learning rate scheduler.
        3. Get data loaders and class weights.
        4. Get loss functions: cross entropy loss and weighted loss functions.
        5. Get logger, evaluator, and saver.
        6. Run training loop, evaluate and save model periodically.
    """
    model_args = args.model_args
    logger_args = args.logger_args
    optim_args = args.optim_args
    data_args = args.data_args
    transform_args = args.transform_args

    task_sequence = TASK_SEQUENCES[data_args.task_sequence]
    print('gpus: ', args.gpu_ids)
    # Get model
    if model_args.ckpt_path:
        model_args.pretrained = False
        model, ckpt_info = ModelSaver.load_model(model_args.ckpt_path, args.gpu_ids, model_args, data_args)
        if not logger_args.restart_epoch_count:
            args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[model_args.model]
        model = model_fn(task_sequence, model_args)
        num_covars = len(model_args.covar_list.split(';'))
        model.transform_model_shape(len(task_sequence), num_covars)
        if model_args.hierarchy:
            model = models.HierarchyWrapper(model, task_sequence)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = util.get_optimizer(model.parameters(), optim_args)
    lr_scheduler = util.get_scheduler(optimizer, optim_args)

    # The optimizer is loaded from the ckpt if one exists and the new model
    # architecture is the same as the old one (classifier is not transformed).
    if model_args.ckpt_path and not model_args.transform_classifier:
        ModelSaver.load_optimizer(model_args.ckpt_path, args.gpu_ids, optimizer, lr_scheduler)

    # Get loaders and class weights
    train_csv_name = 'train'
    if data_args.uncertain_map_path is not None:
        train_csv_name = data_args.uncertain_map_path

    # Put all CXR training fractions into one dictionary and pass it to the loader
    cxr_frac = {'pocus': data_args.pocus_train_frac, 'hocus': data_args.hocus_train_frac,
                'pulm': data_args.pulm_train_frac}
    train_loader = get_loader(data_args,
                              transform_args,
                              train_csv_name,
                              task_sequence,
                              data_args.su_train_frac,
                              data_args.nih_train_frac,
                              cxr_frac,
                              data_args.tcga_train_frac,
                              args.batch_size,
                              frontal_lateral=model_args.frontal_lateral,
                              is_training=True,
                              shuffle=True,
                              covar_list=model_args.covar_list,
                              fold_num=data_args.fold_num)
    eval_loaders = get_eval_loaders(data_args,
                                    transform_args,
                                    task_sequence,
                                    args.batch_size,
                                    frontal_lateral=model_args.frontal_lateral,
                                    covar_list=model_args.covar_list,
                                    fold_num=data_args.fold_num)
    class_weights = train_loader.dataset.class_weights

    # Get loss functions
    uw_loss_fn = get_loss_fn(args.loss_fn, args.device, model_args.model_uncertainty,
        args.has_tasks_missing, class_weights=class_weights)
    w_loss_fn = get_loss_fn('weighted_loss', args.device, model_args.model_uncertainty,
        args.has_tasks_missing, class_weights=class_weights)

    # Get logger, evaluator and saver
    logger = TrainLogger(logger_args, args.start_epoch, args.num_epochs, args.batch_size,
        len(train_loader.dataset), args.device, normalization=transform_args.normalization)
    
    eval_args = {}
    eval_args['num_visuals'] = logger_args.num_visuals
    eval_args['iters_per_eval'] = logger_args.iters_per_eval
    eval_args['has_missing_tasks'] = args.has_tasks_missing
    eval_args['model_uncertainty'] = model_args.model_uncertainty
    eval_args['class_weights'] = class_weights
    eval_args['max_eval'] = logger_args.max_eval
    eval_args['device'] = args.device
    eval_args['optimizer'] = optimizer
    evaluator = get_evaluator('classification', eval_loaders, logger, eval_args)

    print("Eval Loaders: %d" % len(eval_loaders))
    saver = ModelSaver(**vars(logger_args))

    metrics = None
    lr_step = 0
    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, targets, info_dict, covars in train_loader:
            logger.start_iter()

            # Evaluate and save periodically
            metrics, curves = evaluator.evaluate(model, args.device, logger.global_step)
            logger.plot_metrics(metrics)
            metric_val = metrics.get(logger_args.metric_name, None)
            assert logger.global_step % logger_args.iters_per_eval != 0 or metric_val is not None
            saver.save(logger.global_step, logger.epoch, model, optimizer, lr_scheduler, args.device,
                       metric_val=metric_val, covar_list=model_args.covar_list)
            lr_step = util.step_scheduler(lr_scheduler, metrics, lr_step, best_ckpt_metric=logger_args.metric_name)

            # Input: [batch_size, channels, width, height]

            with torch.set_grad_enabled(True):
            # with torch.autograd.set_detect_anomaly(True):

                logits = model.forward([inputs.to(args.device), covars])

                # Scale up TB so that its loss counts for more when upweight_tb is True.
                if model_args.upweight_tb is True:
                    tb_targets = targets.narrow(1, 0, 1)
                    findings_targets = targets.narrow(1, 1, targets.shape[1] - 1)
                    tb_targets = tb_targets.repeat(1, targets.shape[1] - 1)
                    new_targets = torch.cat((tb_targets, findings_targets), 1)

                    tb_logits = logits.narrow(1, 0, 1)
                    findings_logits = logits.narrow(1, 1, logits.shape[1] - 1)
                    tb_logits = tb_logits.repeat(1, logits.shape[1] - 1)
                    new_logits = torch.cat((tb_logits, findings_logits), 1)
                else:
                    new_logits = logits
                    new_targets = targets

                unweighted_loss = uw_loss_fn(new_logits, new_targets.to(args.device))

                weighted_loss = w_loss_fn(logits, targets.to(args.device)) if w_loss_fn else None

                logger.log_iter(inputs, logits, targets, unweighted_loss, weighted_loss, optimizer)

                optimizer.zero_grad()
                if args.loss_fn == 'weighted_loss':
                    weighted_loss.backward()
                else:
                    unweighted_loss.backward()
                optimizer.step()

            logger.end_iter()

        logger.end_epoch(metrics, optimizer)
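
The upweight_tb branch in Example #7 duplicates the first (TB) column of logits and targets so its loss contributes as much as all remaining findings combined. A toy run of the same narrow / repeat / cat calls makes the reshaping concrete (the values are made up):

import torch

targets = torch.tensor([[1., 0., 1., 0.]])              # [batch, num_tasks]
tb = targets.narrow(1, 0, 1)                            # TB column: [[1.]]
findings = targets.narrow(1, 1, targets.shape[1] - 1)   # [[0., 1., 0.]]
tb = tb.repeat(1, targets.shape[1] - 1)                 # [[1., 1., 1.]]
new_targets = torch.cat((tb, findings), 1)              # [[1., 1., 1., 0., 1., 0.]]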