Example #1
def train(args):
    # Get model
    model = models.__dict__[args.model](args)
    if args.ckpt_path:
        model = ModelSaver.load_model(model, args.ckpt_path, args.gpu_ids, is_training=True)
    model = model.to(args.device)
    model.train()

    # Get loader, logger, and saver
    train_loader, val_loader = get_data_loaders(args)
    logger = TrainLogger(args, model, dataset_len=len(train_loader.dataset))
    saver = ModelSaver(args.save_dir, args.max_ckpts, metric_name=args.metric_name,
                       maximize_metric=args.maximize_metric, keep_topk=True)

    # Train
    while not logger.is_finished_training():
        logger.start_epoch()
        for batch in train_loader:
            logger.start_iter()

            # Train over one batch
            model.set_inputs(batch['src'], batch['tgt'])
            model.train_iter()

            logger.end_iter()

            # Evaluate
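            # Fires roughly once every iters_per_eval examples, assuming
            # global_step advances by batch_size each iteration.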
            if logger.global_step % args.iters_per_eval < args.batch_size:
                criteria = {'MSE_src2tgt': mse, 'MSE_tgt2src': mse}
                stats = evaluate(model, val_loader, criteria)
                logger.log_scalars({'val_' + k: v for k, v in stats.items()})
                saver.save(logger.global_step, model,
                           stats[args.metric_name], args.device)

        logger.end_epoch()
Example #2
def train(args):

    if args.ckpt_path:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(**vars(args))
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = optim.get_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), args)
    lr_scheduler = optim.get_scheduler(optimizer, args)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    loss_fn = nn.CrossEntropyLoss()
    train_loader = CIFARLoader('train', args.batch_size, args.num_workers)
    logger = TrainLogger(args, len(train_loader.dataset))
    eval_loaders = [CIFARLoader('val', args.batch_size, args.num_workers)]
    evaluator = ModelEvaluator(eval_loaders, logger, args.max_eval,
                               args.epochs_per_eval)
    saver = ModelSaver(**vars(args))

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, targets in train_loader:
            logger.start_iter()

            with torch.set_grad_enabled(True):
                logits = model.forward(inputs.to(args.device))
                loss = loss_fn(logits, targets.to(args.device))

                logger.log_iter(loss)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            logger.end_iter()

        metrics = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch,
                   model,
                   optimizer,
                   lr_scheduler,
                   args.device,
                   metric_val=metrics.get(args.metric_name, None))
        logger.end_epoch(metrics)
        optim.step_scheduler(lr_scheduler, metrics, logger.epoch)
Example #3
def train(args):
    train_loader = get_loader(args=args)
    if args.ckpt_path:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        args.D_in = train_loader.D_in
        model = model_fn(**vars(args))
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = optim.get_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), args)
    lr_scheduler = optim.get_scheduler(optimizer, args)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    loss_fn = optim.get_loss_fn(args.loss_fn, args)

    logger = TrainLogger(args, len(train_loader.dataset))
    eval_loaders = [
        get_loader(args, phase='train', is_training=False),
        get_loader(args, phase='valid', is_training=False)
    ]
    evaluator = ModelEvaluator(args, eval_loaders, logger, args.max_eval,
                               args.epochs_per_eval)

    saver = ModelSaver(**vars(args))

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for src, tgt in train_loader:
            logger.start_iter()
            with torch.set_grad_enabled(True):
                pred_params = model.forward(src.to(args.device))
                ages = src[:, 1]
                loss = loss_fn(pred_params, tgt.to(args.device),
                               ages.to(args.device), args.use_intvl)
                logger.log_iter(src, pred_params, tgt, loss)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            logger.end_iter()

        metrics = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch, model, optimizer, lr_scheduler, args.device,
                   metric_val=metrics.get(args.metric_name, None))
        logger.end_epoch(metrics=metrics)
Example #4
def train(args):

    if args.ckpt_path:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(**vars(args))
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Set up population-based training client
    pbt_client = PBTClient(args.pbt_server_url, args.pbt_server_port, args.pbt_server_key, args.pbt_config_path)

    # Get optimizer and scheduler
    parameters = model.module.parameters()
    optimizer = optim.get_optimizer(parameters, args, pbt_client)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, args.gpu_ids, optimizer)

    # Get logger, evaluator, saver
    train_loader = DataLoader(args, 'train', is_training_set=True)
    eval_loaders = [DataLoader(args, 'valid', is_training_set=False)]
    evaluator = ModelEvaluator(eval_loaders, args.epochs_per_eval,
                               args.max_eval, args.num_visuals, use_ten_crop=args.use_ten_crop)
    saver = ModelSaver(**vars(args))

    for _ in range(args.num_epochs):
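        # Pull the latest hyperparameters from the PBT server each epoch.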
        optim.update_hyperparameters(model.module, optimizer, pbt_client.hyperparameters())

        for inputs, targets in train_loader:
            with torch.set_grad_enabled(True):
                logits = model.forward(inputs.to(args.device))
                loss = F.binary_cross_entropy_with_logits(logits, targets.to(args.device))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        metrics = evaluator.evaluate(model, args.device)
        metric_val = metrics.get(args.metric_name, None)
        ckpt_path = saver.save(model, args.model, optimizer, args.device, metric_val)

        pbt_client.save(ckpt_path, metric_val)
        if pbt_client.should_exploit():
            # Exploit
            pbt_client.exploit()

            # Load model and optimizer parameters from exploited network
            model, ckpt_info = ModelSaver.load_model(pbt_client.parameters_path(), args.gpu_ids)
            model = model.to(args.device)
            model.train()
            ModelSaver.load_optimizer(pbt_client.parameters_path(), args.gpu_ids, optimizer)

            # Explore
            pbt_client.explore()
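Example #5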
def train(args):
    """Train the model on the dataset."""
    Dataset = collections.namedtuple('Dataset', 'X y')

    assert args.random or (args.train_path and args.test_path), \
        "Choose random data or provide train and test data paths"

    if args.random:
        train_features = np.random.random((100, 100))
        train_costs = np.random.random((100, ))

        test_features = np.random.random((100, 100))
        test_costs = np.random.random((100, ))
    else:
        train_path = Path(args.train_path)
        train_df = pd.read_csv(train_path)
        train_features, train_costs = preprocess(train_df, args.sdh)

        test_path = Path(args.test_path)
        test_df = pd.read_csv(test_path)
        test_features, test_costs = preprocess(test_df, args.sdh)

    train_dataset = Dataset(train_features, train_costs)
    test_dataset = Dataset(test_features, test_costs)

    # Instantiate the model saver.
    saver = ModelSaver(args.save_dir)

    # Load the model.
    model = get_model(args)

    # Train model on dataset, cross-validating on validation set.
    model.fit(train_dataset)
    saver.save(model)

    preds, targets = predict(model, test_dataset)

    metrics = {
        "R2-score": r2_score(targets, preds),
        "MAE": mean_absolute_error(targets, preds)
    }

    # Print metrics to stdout.
    print(metrics)
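Example #6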
def train(args):

    if args.ckpt_path and not args.use_pretrained:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(**vars(args))
        if args.use_pretrained:
            model.load_pretrained(args.ckpt_path, args.gpu_ids)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    if args.use_pretrained or args.fine_tune:
        parameters = model.module.fine_tuning_parameters(
            args.fine_tuning_boundary, args.fine_tuning_lr)
    else:
        parameters = model.parameters()
    optimizer = util.get_optimizer(parameters, args)
    lr_scheduler = util.get_scheduler(optimizer, args)
    if args.ckpt_path and not args.use_pretrained and not args.fine_tune:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    cls_loss_fn = util.get_loss_fn(is_classification=True,
                                   dataset=args.dataset,
                                   size_average=False)
    data_loader_fn = data_loader.__dict__[args.data_loader]
    train_loader = data_loader_fn(args, phase='train', is_training=True)
    logger = TrainLogger(args, len(train_loader.dataset),
                         train_loader.dataset.pixel_dict)
    eval_loaders = [data_loader_fn(args, phase='val', is_training=False)]
    evaluator = ModelEvaluator(args.do_classify, args.dataset, eval_loaders,
                               logger, args.agg_method, args.num_visuals,
                               args.max_eval, args.epochs_per_eval)
    saver = ModelSaver(args.save_dir, args.epochs_per_save, args.max_ckpts,
                       args.best_ckpt_metric, args.maximize_metric)

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, target_dict in train_loader:
            logger.start_iter()

            with torch.set_grad_enabled(True):
                inputs = inputs.to(args.device)
                cls_logits = model.forward(inputs)
                cls_targets = target_dict['is_abnormal']
                cls_loss = cls_loss_fn(cls_logits, cls_targets.to(args.device))
                loss = cls_loss.mean()

                logger.log_iter(inputs, cls_logits, target_dict,
                                loss, optimizer)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            logger.end_iter()
            util.step_scheduler(lr_scheduler, global_step=logger.global_step)

        metrics, curves = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch,
                   model,
                   optimizer,
                   lr_scheduler,
                   args.device,
                   metric_val=metrics.get(args.best_ckpt_metric, None))
        logger.end_epoch(metrics, curves)
        util.step_scheduler(lr_scheduler,
                            metrics,
                            epoch=logger.epoch,
                            best_ckpt_metric=args.best_ckpt_metric)
Example #7
def train(args):
    # Get loader for outer loop training
    loader = get_loader(args)
    target_image_shape = loader.dataset.target_image_shape
    setattr(args, 'target_image_shape', target_image_shape)

    # Load model
    model_fn = models.__dict__[args.model]
    model = model_fn(**vars(args))
    model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Print model parameters
    print('Model parameters: name, size, mean, std')
    for name, param in model.named_parameters():
        print(name, param.size(), torch.mean(param), torch.std(param))

    # Get optimizer and loss
    parameters = model.parameters()
    optimizer = util.get_optimizer(parameters, args)
    loss_fn = util.get_loss_fn(args.loss_fn, args)

    z_loss_fn = util.get_loss_fn(args.loss_fn, args)

    # Get logger, saver
    logger = TrainLogger(args)
    saver = ModelSaver(args)

    print(f'Logs: {logger.log_dir}')
    print(f'Ckpts: {args.save_dir}')

    # Train model
    logger.log_hparams(args)
    batch_size = args.batch_size
    while not logger.is_finished_training():
        logger.start_epoch()

        for input_noise, target_image, mask, z_test_target, z_test in loader:
            logger.start_iter()

            # Move tensors to the target device (a no-op on CPU)
            input_noise = input_noise.to(args.device)
            target_image = target_image.to(args.device)
            mask = mask.to(args.device)
            z_test = z_test.to(args.device)
            z_test_target = z_test_target.to(args.device)

            masked_target_image = target_image * mask
            obscured_target_image = target_image * (1.0 - mask)

            # Input is noise tensor, target is image
            model.train()
            with torch.set_grad_enabled(True):
                if args.use_intermediate_logits:
                    logits = model.forward(input_noise).float()
                    probs = torch.sigmoid(logits)

                    # Debug logits and diffs
                    logger.debug_visualize(
                        [logits, logits * mask, logits * (1.0 - mask)],
                        unique_suffix='logits-train')
                else:
                    probs = model.forward(input_noise).float()

                # With backprop, calculate (1) masked loss, loss when mask is applied.
                # Loss is done elementwise without reduction, so must take mean after.
                # Easier for debugging.
                masked_probs = probs * mask
                masked_loss = loss_fn(masked_probs, masked_target_image).mean()

                masked_loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            # Without backprop, calculate (2) full loss on the entire image,
            # And (3) the obscured loss, region obscured by mask.
            model.eval()
            with torch.no_grad():
                if args.use_intermediate_logits:
                    logits_eval = model.forward(input_noise).float()
                    probs_eval = torch.sigmoid(logits_eval)

                    # Debug logits and diffs
                    logger.debug_visualize(
                        [logits_eval, logits_eval * mask,
                         logits_eval * (1.0 - mask)],
                        unique_suffix='logits-eval')
                else:
                    probs_eval = model.forward(input_noise).float()

                masked_probs_eval = probs_eval * mask
                masked_loss_eval = loss_fn(masked_probs_eval,
                                           masked_target_image).mean()

                full_loss_eval = loss_fn(probs_eval, target_image).mean()

                obscured_probs_eval = probs_eval * (1.0 - mask)
                obscured_loss_eval = loss_fn(obscured_probs_eval,
                                             obscured_target_image).mean()

            # With backprop on only the input z, (4) run one step of z-test and get z-loss
            z_optimizer = util.get_optimizer([z_test.requires_grad_()], args)
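            # Note: the z optimizer is recreated every iteration, so no
            # optimizer state (e.g. momentum) carries across steps.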
            with torch.set_grad_enabled(True):
                if args.use_intermediate_logits:
                    z_logits = model.forward(z_test).float()
                    z_probs = torch.sigmoid(z_logits)
                else:
                    z_probs = model.forward(z_test).float()

                z_loss = z_loss_fn(z_probs, z_test_target).mean()

                z_loss.backward()
                z_optimizer.step()
                z_optimizer.zero_grad()

            if z_loss < args.max_z_test_loss:  # TODO: include this part into the metrics/saver stuff below
                # Save MSE on obscured region
                final_metrics = {'final/score': obscured_loss_eval.item()}
                logger._log_scalars(final_metrics)
                print('z loss', z_loss)
                print('Final MSE value', obscured_loss_eval)

            # TODO: Make a function for metrics - or at least make sure dict includes all possible best ckpt metrics
            metrics = {'masked_loss': masked_loss.item()}
            saver.save(logger.global_step,
                       model,
                       optimizer,
                       args.device,
                       metric_val=metrics.get(args.best_ckpt_metric, None))
            # Log both train and eval model settings, and visualize their outputs
            logger.log_status(
                inputs=input_noise,
                targets=target_image,
                probs=probs,
                masked_probs=masked_probs,
                masked_loss=masked_loss,
                probs_eval=probs_eval,
                masked_probs_eval=masked_probs_eval,
                obscured_probs_eval=obscured_probs_eval,
                masked_loss_eval=masked_loss_eval,
                obscured_loss_eval=obscured_loss_eval,
                full_loss_eval=full_loss_eval,
                z_target=z_test_target,
                z_probs=z_probs,
                z_loss=z_loss,
                save_preds=args.save_preds,
            )

            logger.end_iter()

        logger.end_epoch()

    # Last log after everything completes
    logger.log_status(
        inputs=input_noise,
        targets=target_image,
        probs=probs,
        masked_probs=masked_probs,
        masked_loss=masked_loss,
        probs_eval=probs_eval,
        masked_probs_eval=masked_probs_eval,
        obscured_probs_eval=obscured_probs_eval,
        masked_loss_eval=masked_loss_eval,
        obscured_loss_eval=obscured_loss_eval,
        full_loss_eval=full_loss_eval,
        z_target=z_test_target,
        z_probs=z_probs,
        z_loss=z_loss,
        save_preds=args.save_preds,
        force_visualize=True,
    )
Example #8
def train(args):
    """Run model training."""

    print("Start Training ...")

    # Get nested namespaces.
    model_args = args.model_args
    logger_args = args.logger_args
    optim_args = args.optim_args
    data_args = args.data_args
    transform_args = args.transform_args

    # Get logger.
    print('Getting logger... log to path: {}'.format(logger_args.log_path))
    logger = Logger(logger_args.log_path, logger_args.save_dir)

    # For conaug, point to the MOCO pretrained weights.
    if model_args.ckpt_path and model_args.ckpt_path != 'None':
        print("Pretrained checkpoint specified: {}".format(
            model_args.ckpt_path))
        # CL-specified args are used to load the model, rather than the
        # ones saved to args.json.
        model_args.pretrained = False
        ckpt_path = model_args.ckpt_path
        model, ckpt_info = ModelSaver.load_model(ckpt_path=ckpt_path,
                                                 gpu_ids=args.gpu_ids,
                                                 model_args=model_args,
                                                 is_training=True)

        if not model_args.moco:
            optim_args.start_epoch = ckpt_info['epoch'] + 1
        else:
            optim_args.start_epoch = 1
    else:
        print('Starting without a pretrained checkpoint; using random initialization.')
        # If no ckpt_path is provided, instantiate a new randomly
        # initialized model.
        model_fn = models.__dict__[model_args.model]
        if data_args.custom_tasks is not None:
            tasks = NamedTasks[data_args.custom_tasks]
        else:
            tasks = model_args.__dict__[TASKS]  # TASKS = "tasks"
        print("Tasks: {}".format(tasks))
        model = model_fn(tasks, model_args)
        model = nn.DataParallel(model, args.gpu_ids)

    # Put model on gpu or cpu and put into training mode.
    model = model.to(args.device)
    model.train()

    print("========= MODEL ==========")
    print(model)

    # Get train and valid loader objects.
    train_loader = get_loader(phase="train",
                              data_args=data_args,
                              transform_args=transform_args,
                              is_training=True,
                              return_info_dict=False,
                              logger=logger)
    valid_loader = get_loader(phase="valid",
                              data_args=data_args,
                              transform_args=transform_args,
                              is_training=False,
                              return_info_dict=False,
                              logger=logger)

    # Instantiate the predictor class for obtaining model predictions.
    predictor = Predictor(model, args.device)
    # Instantiate the evaluator class for evaluating models.
    evaluator = Evaluator(logger)
    # Get the set of tasks which will be used for saving models
    # and annealing learning rate.
    eval_tasks = EVAL_METRIC2TASKS[optim_args.metric_name]

    # Instantiate the saver class for saving model checkpoints.
    saver = ModelSaver(save_dir=logger_args.save_dir,
                       iters_per_save=logger_args.iters_per_save,
                       max_ckpts=logger_args.max_ckpts,
                       metric_name=optim_args.metric_name,
                       maximize_metric=optim_args.maximize_metric,
                       keep_topk=logger_args.keep_topk)

    # TODO: JBY: handle threshold for fine tuning
    if model_args.fine_tuning == 'full':  # Fine tune all layers.
        pass
    else:
        # Freeze other layers.
        models.PretrainedModel.set_require_grad_for_fine_tuning(
            model, model_args.fine_tuning.split(','))

    # Instantiate the optimizer class for guiding model training.
    optimizer = Optimizer(parameters=model.parameters(),
                          optim_args=optim_args,
                          batch_size=data_args.batch_size,
                          iters_per_print=logger_args.iters_per_print,
                          iters_per_visual=logger_args.iters_per_visual,
                          iters_per_eval=logger_args.iters_per_eval,
                          dataset_len=len(train_loader.dataset),
                          logger=logger)

    if model_args.ckpt_path and not model_args.moco:
        # Load the same optimizer as used in the original training.
        optimizer.load_optimizer(ckpt_path=model_args.ckpt_path,
                                 gpu_ids=args.gpu_ids)

    loss_fn = evaluator.get_loss_fn(
        loss_fn_name=optim_args.loss_fn,
        model_uncertainty=model_args.model_uncertainty,
        mask_uncertain=True,
        device=args.device)

    # Run training
    while not optimizer.is_finished_training():
        optimizer.start_epoch()

        # TODO: JBY, HACK WARNING
        metrics = None
        for inputs, targets in train_loader:
            optimizer.start_iter()
            is_eval_step = (optimizer.global_step > 0 and
                            optimizer.global_step % optimizer.iters_per_eval == 0)
            is_near_epoch_end = (len(train_loader.dataset) - optimizer.iter
                                 < optimizer.batch_size)
            if is_eval_step or is_near_epoch_end:

                # Only evaluate every iters_per_eval examples.
                predictions, groundtruth = predictor.predict(valid_loader)
                metrics, curves = evaluator.evaluate_tasks(
                    groundtruth, predictions)
                # Log metrics to stdout.
                logger.log_metrics(metrics)

                # Add logger for all the metrics for valid_loader
                logger.log_scalars(metrics, optimizer.global_step)

                # Get the metric used to save model checkpoints.
                average_metric = evaluator.evaluate_average_metric(
                    metrics, eval_tasks, optim_args.metric_name)

                if optimizer.global_step % logger_args.iters_per_save == 0:
                    # Only save every iters_per_save examples directly
                    # after evaluation.
                    print("Save global step: {}".format(optimizer.global_step))
                    saver.save(iteration=optimizer.global_step,
                               epoch=optimizer.epoch,
                               model=model,
                               optimizer=optimizer,
                               device=args.device,
                               metric_val=average_metric)

                # Step learning rate scheduler.
                optimizer.step_scheduler(average_metric)

            with torch.set_grad_enabled(True):
                logits, embedding = model(inputs.to(args.device))
                loss = loss_fn(logits, targets.to(args.device))
                optimizer.log_iter(inputs, logits, targets, loss)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            optimizer.end_iter()

        optimizer.end_epoch(metrics)

    logger.log('=== Training Complete ===')
Example #9
File: train.py  Project: meajagun/papers
def train(args):
    """Train model.

    Args:
        args: Command line arguments.
    """
    # Set up model
    model = models.__dict__[args.model](**vars(args))
    model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)

    # Set up data loader
    train_loader, test_loader, classes = get_cifar_loaders(
        args.batch_size, args.num_workers)

    # Set up optimizer
    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.sgd_momentum,
                          weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step,
                                          args.lr_decay_gamma)
    loss_fn = nn.CrossEntropyLoss().to(args.device)

    # Set up checkpoint saver
    saver = ModelSaver(model,
                       optimizer,
                       scheduler,
                       args.save_dir, {'model': args.model},
                       max_to_keep=args.max_ckpts,
                       device=args.device)

    # Train
    logger = TrainLogger(args, len(train_loader.dataset))

    while not logger.is_finished_training():
        logger.start_epoch()

        # Train for one epoch
        model.train()
        for inputs, labels in train_loader:
            logger.start_iter()

            with torch.set_grad_enabled(True):
                # Forward
                outputs = model.forward(inputs.to(args.device))
                loss = loss_fn(outputs, labels.to(args.device))
                loss_item = loss.item()

                # Backward
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            logger.end_iter({'loss': loss_item})

        # Evaluate on the held-out test set
        val_loss = evaluate(model, test_loader, loss_fn, device=args.device)
        logger.write('[epoch {}]: val_loss: {:.3g}'.format(
            logger.epoch, val_loss))
        logger.write_summaries({'loss': val_loss}, phase='val')
        if logger.epoch in args.save_epochs:
            saver.save(logger.epoch, val_loss)

        logger.end_epoch()
        scheduler.step()
Example #10
def train(args):
    """Run model training."""

    # Get nested namespaces.
    model_args = args.model_args
    logger_args = args.logger_args
    optim_args = args.optim_args
    data_args = args.data_args

    # Get logger.
    logger = Logger(logger_args)

    if model_args.ckpt_path:
        # CL-specified args are used to load the model, rather than the
        # ones saved to args.json.
        model_args.pretrained = False
        ckpt_path = model_args.ckpt_path
        model, ckpt_info = ModelSaver.load_model(ckpt_path=ckpt_path,
                                                 gpu_ids=args.gpu_ids,
                                                 model_args=model_args,
                                                 is_training=True)
        optim_args.start_epoch = ckpt_info['epoch'] + 1
    else:
        # If no ckpt_path is provided, instantiate a new randomly
        # initialized model.
        model_fn = models.__dict__[model_args.model]
        model = model_fn(model_args)
        model = nn.DataParallel(model, args.gpu_ids)
    # Put model on gpu or cpu and put into training mode.
    model = model.to(args.device)
    model.train()

    # Get train and valid loader objects.
    train_loader = get_loader(phase="train",
                              data_args=data_args,
                              is_training=True,
                              logger=logger)
    valid_loader = get_loader(phase="valid",
                              data_args=data_args,
                              is_training=False,
                              logger=logger)
    dense_valid_loader = get_loader(phase="dense_valid",
                                    data_args=data_args,
                                    is_training=False,
                                    logger=logger)

    # Instantiate the predictor class for obtaining model predictions.
    predictor = Predictor(model, args.device)

    # Instantiate the evaluator class for evaluating models.
    # By default, get best performance on validation set.
    evaluator = Evaluator(logger=logger, tune_threshold=True)

    # Instantiate the saver class for saving model checkpoints.
    saver = ModelSaver(save_dir=logger_args.save_dir,
                       iters_per_save=logger_args.iters_per_save,
                       max_ckpts=logger_args.max_ckpts,
                       metric_name=optim_args.metric_name,
                       maximize_metric=optim_args.maximize_metric,
                       keep_topk=True,
                       logger=logger)

    # Instantiate the optimizer class for guiding model training.
    optimizer = Optimizer(parameters=model.parameters(),
                          optim_args=optim_args,
                          batch_size=data_args.batch_size,
                          iters_per_print=logger_args.iters_per_print,
                          iters_per_visual=logger_args.iters_per_visual,
                          iters_per_eval=logger_args.iters_per_eval,
                          dataset_len=len(train_loader.dataset),
                          logger=logger)
    if model_args.ckpt_path:
        # Load the same optimizer as used in the original training.
        optimizer.load_optimizer(ckpt_path=model_args.ckpt_path,
                                 gpu_ids=args.gpu_ids)

    loss_fn = evaluator.get_loss_fn(loss_fn_name=optim_args.loss_fn)

    # Run training
    while not optimizer.is_finished_training():
        optimizer.start_epoch()

        for inputs, targets in train_loader:
            optimizer.start_iter()

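            # Note: 0 % iters_per_eval == 0, so evaluation also runs before
            # the first training step.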
            if optimizer.global_step % optimizer.iters_per_eval == 0:
                # Only evaluate every iters_per_eval examples.
                predictions, groundtruth = predictor.predict(valid_loader)
                metrics = evaluator.evaluate(groundtruth, predictions)

                # Evaluate on dense dataset
                dense_predictions, dense_groundtruth = predictor.predict(
                    dense_valid_loader)
                dense_metrics = evaluator.dense_evaluate(
                    dense_groundtruth, dense_predictions)
                # Merge the metrics dicts together
                metrics = {**metrics, **dense_metrics}

                # Log metrics to stdout.
                logger.log_metrics(metrics, phase='valid')

                # Log to tb
                logger.log_scalars(metrics,
                                   optimizer.global_step,
                                   phase='valid')

                if optimizer.global_step % logger_args.iters_per_save == 0:
                    # Only save every iters_per_save examples directly
                    # after evaluation.
                    saver.save(iteration=optimizer.global_step,
                               epoch=optimizer.epoch,
                               model=model,
                               optimizer=optimizer,
                               device=args.device,
                               metric_val=metrics[optim_args.metric_name])

                # Step learning rate scheduler.
                optimizer.step_scheduler(metrics[optim_args.metric_name])

            with torch.set_grad_enabled(True):

                # Run the minibatch through the model.
                logits = model(inputs.to(args.device))

                # Compute the minibatch loss.
                loss = loss_fn(logits, targets.to(args.device))

                # Log the data from this iteration.
                optimizer.log_iter(inputs, logits, targets, loss)

                # Perform a backward pass.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            optimizer.end_iter()

        optimizer.end_epoch(metrics)

    # Save the most recent model.
    saver.save(iteration=optimizer.global_step,
               epoch=optimizer.epoch,
               model=model,
               optimizer=optimizer,
               device=args.device,
               metric_val=metrics[optim_args.metric_name])
Example #11
File: train.py  Project: yxliang/lca-code
def train(args):
    """Run training loop with the given args.

    The function consists of the following steps:
        1. Load model: gets the model from a checkpoint or from models/models.py.
        2. Load optimizer and learning rate scheduler.
        3. Get data loaders and class weights.
        4. Get loss functions: cross entropy loss and weighted loss functions.
        5. Get logger, evaluator, and saver.
        6. Run training loop, evaluate and save model periodically.
    """

    model_args = args.model_args
    logger_args = args.logger_args
    optim_args = args.optim_args
    data_args = args.data_args
    transform_args = args.transform_args

    task_sequence = TASK_SEQUENCES[data_args.task_sequence]

    # Get model
    if model_args.ckpt_path:
        model_args.pretrained = False
        model, ckpt_info = ModelSaver.load_model(model_args.ckpt_path,
                                                 args.gpu_ids, model_args,
                                                 data_args)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[model_args.model]
        model = model_fn(task_sequence, model_args)
        if model_args.hierarchy:
            model = models.HierarchyWrapper(model, task_sequence)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = util.get_optimizer(model.parameters(), optim_args)
    lr_scheduler = util.get_scheduler(optimizer, optim_args)
    if model_args.ckpt_path:
        ModelSaver.load_optimizer(model_args.ckpt_path, args.gpu_ids,
                                  optimizer, lr_scheduler)

    # Get loaders and class weights
    train_csv_name = 'train'
    if data_args.uncertain_map_path is not None:
        train_csv_name = data_args.uncertain_map_path
    #TODO: Remove this when we decide which transformation to use in the end
    #transforms_imgaug = ImgAugTransform()
    train_loader = get_loader(data_args,
                              transform_args,
                              train_csv_name,
                              task_sequence,
                              data_args.su_train_frac,
                              data_args.nih_train_frac,
                              data_args.pocus_train_frac,
                              data_args.tcga_train_frac,
                              0,
                              0,
                              args.batch_size,
                              frontal_lateral=model_args.frontal_lateral,
                              is_training=True,
                              shuffle=True,
                              transform=model_args.transform,
                              normalize=model_args.normalize)
    eval_loaders = get_eval_loaders(data_args,
                                    transform_args,
                                    task_sequence,
                                    args.batch_size,
                                    frontal_lateral=model_args.frontal_lateral,
                                    normalize=model_args.normalize)
    class_weights = train_loader.dataset.class_weights
    print("Class weights:")
    print(class_weights)

    # Get loss functions
    uw_loss_fn = get_loss_fn('cross_entropy',
                             args.device,
                             model_args.model_uncertainty,
                             args.has_tasks_missing,
                             class_weights=class_weights)

    w_loss_fn = get_loss_fn('weighted_loss',
                            args.device,
                            model_args.model_uncertainty,
                            args.has_tasks_missing,
                            mask_uncertain=False,
                            class_weights=class_weights)

    # Get logger, evaluator and saver
    logger = TrainLogger(logger_args, args.start_epoch,
                         args.num_epochs, args.batch_size,
                         len(train_loader.dataset), args.device)

    eval_args = {}
    eval_args['num_visuals'] = logger_args.num_visuals
    eval_args['iters_per_eval'] = logger_args.iters_per_eval
    eval_args['has_missing_tasks'] = args.has_tasks_missing
    eval_args['model_uncertainty'] = model_args.model_uncertainty
    eval_args['class_weights'] = class_weights
    eval_args['max_eval'] = logger_args.max_eval
    eval_args['device'] = args.device
    eval_args['optimizer'] = args.optimizer
    evaluator = get_evaluator('classification', eval_loaders, logger,
                              eval_args)

    print("Eval Loaders: %d" % len(eval_loaders))
    saver = ModelSaver(**vars(logger_args))

    metrics = None
    lr_step = 0
    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, targets, info_dict in train_loader:

            logger.start_iter()

            # Evaluate and save periodically
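            # evaluator.evaluate is presumably rate-limited internally via
            # eval_args['iters_per_eval']; metric_val may be None off-cycle,
            # which the assert below checks on eval steps.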
            metrics, curves = evaluator.evaluate(model, args.device,
                                                 logger.global_step)
            logger.plot_metrics(metrics)
            metric_val = metrics.get(logger_args.metric_name, None)

            assert logger.global_step % logger_args.iters_per_eval != 0 or metric_val is not None
            saver.save(logger.global_step,
                       logger.epoch,
                       model,
                       optimizer,
                       lr_scheduler,
                       args.device,
                       metric_val=metric_val)
            lr_step = util.step_scheduler(
                lr_scheduler,
                metrics,
                lr_step,
                best_ckpt_metric=logger_args.metric_name)

            # Input: [batch_size, channels, width, height]

            with torch.set_grad_enabled(True):

                logits = model.forward(inputs.to(args.device))

                unweighted_loss = uw_loss_fn(logits, targets.to(args.device))

                weighted_loss = w_loss_fn(logits, targets.to(
                    args.device)) if w_loss_fn else None

                logger.log_iter(inputs, logits, targets, unweighted_loss,
                                weighted_loss, optimizer)

                optimizer.zero_grad()
                if args.loss_fn == 'weighted_loss':
                    weighted_loss.backward()
                else:
                    unweighted_loss.backward()
                optimizer.step()

            logger.end_iter()

        logger.end_epoch(metrics, optimizer)
Example #12
def train(args):
    write_args(args)

    model_args = args.model_args
    data_args = args.data_args
    logger_args = args.logger_args

    print(f"Training {logger_args.name}")

    power_constraint = PowerConstraint()
    possible_inputs = get_md_set(model_args.md_len)
    channel = get_channel(data_args.channel, model_args.modelfree, data_args)

    model = AutoEncoder(model_args, data_args, power_constraint, channel,
                        possible_inputs)
    enc_scheduler = get_scheduler(model_args.scheduler, model_args.decay,
                                  model_args.patience)
    dec_scheduler = get_scheduler(model_args.scheduler, model_args.decay,
                                  model_args.patience)

    enc_scheduler.set_model(model.trainable_encoder)
    dec_scheduler.set_model(model.trainable_decoder)
    dataset_size = data_args.batch_size * data_args.batches_per_epoch * data_args.num_epochs
    loader = InputDataloader(data_args.batch_size, data_args.block_length,
                             dataset_size)
    loader = loader.example_generator()
    logger = TrainLogger(logger_args.save_dir, logger_args.name,
                         data_args.num_epochs, logger_args.iters_per_print)

    saver = ModelSaver(logger_args.save_dir, logger)

    enc_scheduler.on_train_begin()
    dec_scheduler.on_train_begin()

    while True:  # Loop until StopIteration
        try:
            metrics = None
            logger.start_epoch()
            for step in range(data_args.batches_per_epoch //
                              (model_args.train_ratio + 1)):
                # encoder train
                logger.start_iter()
                msg = next(loader)
                metrics = model.train_encoder(msg)
                logger.log_iter(metrics)
                logger.end_iter()

                # decoder train
                for _ in range(model_args.train_ratio):
                    logger.start_iter()
                    msg = next(loader)
                    metrics = model.train_decoder(msg)
                    logger.log_iter(metrics)
                    logger.end_iter()
            logger.end_epoch(None)

            if model_args.modelfree:
                model.Pi.std *= model_args.sigma_decay

            enc_scheduler.on_epoch_end(logger.epoch, logs=metrics)
            dec_scheduler.on_epoch_end(logger.epoch, logs=metrics)

            if logger.has_improved():
                saver.save(model)

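            # Early stopping: quit after 7 epochs without improvement.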
            if logger.notImprovedCounter >= 7:
                break
        except StopIteration:
            break
Example #13
def train_classifier(args, model):
    """Train a classifier and save its first-layer weights.

    Args:
        args: Command line arguments.
        model: Classifier model to train.
    """
    # Set up data loader
    train_loader, test_loader, classes = get_data_loaders(
        args.dataset, args.batch_size, args.num_workers)

    # Set up model
    model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)

    fd = None
    if args.use_fd:
        fd = models.filter_discriminator()
        fd = nn.DataParallel(fd, args.gpu_ids)
        fd = fd.to(args.device)

    # Set up optimizer
    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.sgd_momentum,
                          weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step,
                                          args.lr_decay_gamma)
    if args.model == 'fd':
        post_process = nn.Sigmoid()
        loss_fn = nn.MSELoss().to(args.device)
    else:
        post_process = nn.Sequential()  # Identity
        loss_fn = nn.CrossEntropyLoss().to(args.device)

    # Set up checkpoint saver
    saver = ModelSaver(model,
                       optimizer,
                       scheduler,
                       args.save_dir, {'model': args.model},
                       max_to_keep=args.max_ckpts,
                       device=args.device)

    # Train
    logger = TrainLogger(args, len(train_loader.dataset))
    if args.save_all:
        # Save initialized model weights with validation loss as random
        saver.save(0, math.log(args.num_classes))
    while not logger.is_finished_training():
        logger.start_epoch()

        # Train for one epoch
        model.train()
        fd_lambda = get_fd_lambda(args, logger.epoch)
        for inputs, labels in train_loader:
            logger.start_iter()

            with torch.set_grad_enabled(True):
                # Forward
                outputs = model.forward(inputs.to(args.device))
                outputs = post_process(outputs)
                loss = loss_fn(outputs, labels.to(args.device))
                loss_item = loss.item()

                fd_loss = torch.zeros([], dtype=torch.float32,
                                      device=args.device)
                tp_total = torch.zeros([], dtype=torch.float32,
                                       device=args.device)
                if fd is not None:
                    # Forward FD
                    filters = get_layer_weights(model, filter_dict[args.model])
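                    # Score filters in batches; fd_loss is 1 minus the mean
                    # sigmoid score over all first-layer filters.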
                    for i in range(0, filters.size(0), args.fd_batch_size):
                        fd_batch = filters[i:i + args.fd_batch_size]
                        tp_scores = torch.sigmoid(fd.forward(fd_batch))
                        tp_total += tp_scores.sum()
                    fd_loss = 1. - tp_total / filters.size(0)

                fd_loss_item = fd_loss.item()
                loss += fd_lambda * fd_loss

                # Backward
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            logger.end_iter({
                'std_loss': loss_item,
                'fd_loss': fd_loss_item,
                'loss': loss_item + fd_loss_item
            })

        # Evaluate on the held-out test set
        val_loss = evaluate(model,
                            post_process,
                            test_loader,
                            loss_fn,
                            device=args.device)
        logger.write('[epoch {}]: val_loss: {:.3g}'.format(
            logger.epoch, val_loss))
        logger.write_summaries({'loss': val_loss}, phase='val')
        if args.save_all or logger.epoch in args.save_epochs:
            saver.save(logger.epoch, val_loss)

        logger.end_epoch()
        scheduler.step()
Example #14
def train(args):

    if args.ckpt_path:
        model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
        args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[args.model]
        model = model_fn(pretrained=args.pretrained)
        if args.pretrained:
            model.fc = nn.Linear(model.fc.in_features, args.num_classes)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    parameters = optim.get_parameters(model.module, args)
    optimizer = optim.get_optimizer(parameters, args)
    lr_scheduler = optim.get_scheduler(optimizer, args)
    if args.ckpt_path:
        ModelSaver.load_optimizer(args.ckpt_path, optimizer, lr_scheduler)

    # Get logger, evaluator, saver
    loss_fn = nn.CrossEntropyLoss()
    train_loader = WhiteboardLoader(args.data_dir,
                                    'train',
                                    args.batch_size,
                                    shuffle=True,
                                    do_augment=True,
                                    num_workers=args.num_workers)
    logger = TrainLogger(args, len(train_loader.dataset))
    eval_loaders = [
        WhiteboardLoader(args.data_dir,
                         'val',
                         args.batch_size,
                         shuffle=False,
                         do_augment=False,
                         num_workers=args.num_workers)
    ]
    evaluator = ModelEvaluator(eval_loaders, logger, args.epochs_per_eval,
                               args.max_eval, args.num_visuals)
    saver = ModelSaver(**vars(args))

    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, targets, paths in train_loader:
            logger.start_iter()

            with torch.set_grad_enabled(True):
                logits = model.forward(inputs.to(args.device))
                loss = loss_fn(logits, targets.to(args.device))

                logger.log_iter(inputs, logits, targets, paths, loss)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            optim.step_scheduler(lr_scheduler, global_step=logger.global_step)
            logger.end_iter()

        metrics = evaluator.evaluate(model, args.device, logger.epoch)
        saver.save(logger.epoch,
                   model,
                   args.model,
                   optimizer,
                   lr_scheduler,
                   args.device,
                   metric_val=metrics.get(args.metric_name, None))
        logger.end_epoch(metrics)
        optim.step_scheduler(lr_scheduler, metrics, logger.epoch)
Example #15
def train(args):
    """Run training loop with the given args.

    The function consists of the following steps:
        1. Load model: gets the model from a checkpoint or from models/models.py.
        2. Load optimizer and learning rate scheduler.
        3. Get data loaders and class weights.
        4. Get loss functions: cross entropy loss and weighted loss functions.
        5. Get logger, evaluator, and saver.
        6. Run training loop, evaluate and save model periodically.
    """
    model_args = args.model_args
    logger_args = args.logger_args
    optim_args = args.optim_args
    data_args = args.data_args
    transform_args = args.transform_args

    task_sequence = TASK_SEQUENCES[data_args.task_sequence]
    print('GPUs:', args.gpu_ids)
    # Get model
    if model_args.ckpt_path:
        model_args.pretrained = False
        model, ckpt_info = ModelSaver.load_model(model_args.ckpt_path, args.gpu_ids, model_args, data_args)
        if not logger_args.restart_epoch_count:
            args.start_epoch = ckpt_info['epoch'] + 1
    else:
        model_fn = models.__dict__[model_args.model]
        model = model_fn(task_sequence, model_args)
        num_covars = len(model_args.covar_list.split(';'))
        model.transform_model_shape(len(task_sequence), num_covars)
        if model_args.hierarchy:
            model = models.HierarchyWrapper(model, task_sequence)
        model = nn.DataParallel(model, args.gpu_ids)
    model = model.to(args.device)
    model.train()

    # Get optimizer and scheduler
    optimizer = util.get_optimizer(model.parameters(), optim_args)
    lr_scheduler = util.get_scheduler(optimizer, optim_args)

    # The optimizer is loaded from the ckpt if one exists and the new model
    # architecture is the same as the old one (classifier is not transformed).
    if model_args.ckpt_path and not model_args.transform_classifier:
        ModelSaver.load_optimizer(model_args.ckpt_path, args.gpu_ids, optimizer, lr_scheduler)

    # Get loaders and class weights
    train_csv_name = 'train'
    if data_args.uncertain_map_path is not None:
        train_csv_name = data_args.uncertain_map_path

    # Put all CXR training fractions into one dictionary and pass it to the loader
    cxr_frac = {'pocus': data_args.pocus_train_frac, 'hocus': data_args.hocus_train_frac,
                'pulm': data_args.pulm_train_frac}
    train_loader = get_loader(data_args,
                              transform_args,
                              train_csv_name,
                              task_sequence,
                              data_args.su_train_frac,
                              data_args.nih_train_frac,
                              cxr_frac,
                              data_args.tcga_train_frac,
                              args.batch_size,
                              frontal_lateral=model_args.frontal_lateral,
                              is_training=True,
                              shuffle=True,
                              covar_list=model_args.covar_list,
                              fold_num=data_args.fold_num)
    eval_loaders = get_eval_loaders(data_args,
                                    transform_args,
                                    task_sequence,
                                    args.batch_size,
                                    frontal_lateral=model_args.frontal_lateral,
                                    covar_list=model_args.covar_list,
                                    fold_num=data_args.fold_num)
    class_weights = train_loader.dataset.class_weights

    # Get loss functions
    uw_loss_fn = get_loss_fn(args.loss_fn, args.device, model_args.model_uncertainty,
        args.has_tasks_missing, class_weights=class_weights)
    w_loss_fn = get_loss_fn('weighted_loss', args.device, model_args.model_uncertainty,
        args.has_tasks_missing, class_weights=class_weights)

    # Get logger, evaluator and saver
    logger = TrainLogger(logger_args, args.start_epoch, args.num_epochs, args.batch_size,
        len(train_loader.dataset), args.device, normalization=transform_args.normalization)
    
    eval_args = {}
    eval_args['num_visuals'] = logger_args.num_visuals
    eval_args['iters_per_eval'] = logger_args.iters_per_eval
    eval_args['has_missing_tasks'] = args.has_tasks_missing
    eval_args['model_uncertainty'] = model_args.model_uncertainty
    eval_args['class_weights'] = class_weights
    eval_args['max_eval'] = logger_args.max_eval
    eval_args['device'] = args.device
    eval_args['optimizer'] = optimizer
    evaluator = get_evaluator('classification', eval_loaders, logger, eval_args)

    print("Eval Loaders: %d" % len(eval_loaders))
    saver = ModelSaver(**vars(logger_args))

    metrics = None
    lr_step = 0
    # Train model
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, targets, info_dict, covars in train_loader:
            logger.start_iter()

            # Evaluate and save periodically
            metrics, curves = evaluator.evaluate(model, args.device, logger.global_step)
            logger.plot_metrics(metrics)
            metric_val = metrics.get(logger_args.metric_name, None)
            assert logger.global_step % logger_args.iters_per_eval != 0 or metric_val is not None
            saver.save(logger.global_step, logger.epoch, model, optimizer, lr_scheduler, args.device,
                       metric_val=metric_val, covar_list=model_args.covar_list)
            lr_step = util.step_scheduler(lr_scheduler, metrics, lr_step, best_ckpt_metric=logger_args.metric_name)

            # Input: [batch_size, channels, width, height]

            with torch.set_grad_enabled(True):

                logits = model.forward([inputs.to(args.device), covars])

                # Scale TB up so its loss counts for more when upweight_tb is set.
                if model_args.upweight_tb:
                    tb_targets = targets.narrow(1, 0, 1)
                    findings_targets = targets.narrow(1, 1, targets.shape[1] - 1)
                    tb_targets = tb_targets.repeat(1, targets.shape[1] - 1)
                    new_targets = torch.cat((tb_targets, findings_targets), 1)

                    tb_logits = logits.narrow(1, 0, 1)
                    findings_logits = logits.narrow(1, 1, logits.shape[1] - 1)
                    tb_logits = tb_logits.repeat(1, logits.shape[1] - 1)
                    new_logits = torch.cat((tb_logits, findings_logits), 1)
                else:
                    new_logits = logits
                    new_targets = targets

                unweighted_loss = uw_loss_fn(new_logits, new_targets.to(args.device))

                weighted_loss = w_loss_fn(logits, targets.to(args.device)) if w_loss_fn else None

                logger.log_iter(inputs, logits, targets, unweighted_loss, weighted_loss, optimizer)

                optimizer.zero_grad()
                if args.loss_fn == 'weighted_loss':
                    weighted_loss.backward()
                else:
                    unweighted_loss.backward()
                optimizer.step()

            logger.end_iter()

        logger.end_epoch(metrics, optimizer)