Example #1
0
    def backward_G_B(self):
        """Compute the generator losses for domain B, sum them into
        ``self.loss_G_B`` and backpropagate.

        The graph is retained (``retain_graph=True``) because other
        backward passes reuse parts of it.
        """
        # Adversarial loss: fake-B predictions vs. real-side B predictions.
        # BUGFIX: pass ``self.outputBpred`` (domain B) instead of
        # ``self.outputApred`` — the original was a copy-paste slip from
        # the symmetric backward_G_A method, which uses the A-side pair.
        self.loss_G_adversarial_B = loss.adversarial_loss_generator(
            self.fakeBpred,
            self.outputBpred,
            method='L2',
            loss_weight_config=self.loss_weight_config)

        # L1 reconstruction between network output and the real image.
        self.loss_G_reconstruction_B = loss.reconstruction_loss(
            self.outputB,
            self.realB,
            method='L1',
            loss_weight_config=self.loss_weight_config)

        # Mask regularization against the configured threshold.
        self.loss_G_mask_B = loss.mask_loss(
            self.maskB,
            threshold=self.loss_config['mask_threshold'],
            method='L1',
            loss_weight_config=self.loss_weight_config)

        self.loss_G_B = self.loss_G_adversarial_B + self.loss_G_reconstruction_B + self.loss_G_mask_B

        # Optional VGG-face perceptual loss.
        if self.loss_config['pl_on']:
            self.loss_G_perceptual_B = loss.perceptual_loss(
                self.realB,
                self.fakeB,
                self.vggface,
                self.vggface_for_pl,
                method='L2',
                loss_weight_config=self.loss_weight_config)
            self.loss_G_B += self.loss_G_perceptual_B

        # Optional edge loss restricted by the eye mask.
        if self.loss_config['edgeloss_on']:
            self.loss_G_edge_B = loss.edge_loss(
                self.outputB,
                self.realB,
                self.mask_eye_B,
                method='L1',
                loss_weight_config=self.loss_weight_config)
            self.loss_G_B += self.loss_G_edge_B

        # Optional eye-region loss.
        if self.loss_config['eyeloss_on']:
            self.loss_G_eye_B = loss.eye_loss(
                self.outputB,
                self.realB,
                self.mask_eye_B,
                method='L1',
                loss_weight_config=self.loss_weight_config)
            self.loss_G_B += self.loss_G_eye_B

        self.loss_G_B.backward(retain_graph=True)
Example #2
0
    def backward_G_A(self):
        """Compute the generator losses for domain A, sum them into
        ``self.loss_G_A`` and backpropagate (graph retained for reuse).
        """
        weight_cfg = self.loss_weight_config

        # Mandatory terms: adversarial, reconstruction and mask losses.
        self.loss_G_adversarial_A = loss.adversarial_loss_generator(
            self.fakeApred, self.outputApred,
            method='L2', loss_weight_config=weight_cfg)

        self.loss_G_reconstruction_A = loss.reconstruction_loss(
            self.outputA, self.realA,
            method='L1', loss_weight_config=weight_cfg)

        self.loss_G_mask_A = loss.mask_loss(
            self.maskA, threshold=self.loss_config['mask_threshold'],
            method='L1', loss_weight_config=weight_cfg)

        total = (self.loss_G_adversarial_A + self.loss_G_reconstruction_A
                 + self.loss_G_mask_A)

        # Optional VGG-face perceptual term.
        if self.loss_config['pl_on']:
            self.loss_G_perceptual_A = loss.perceptual_loss(
                self.realA, self.fakeA, self.vggface, self.vggface_for_pl,
                method='L2', loss_weight_config=weight_cfg)
            total = total + self.loss_G_perceptual_A

        # Optional edge term, restricted by the eye mask.
        if self.loss_config['edgeloss_on']:
            self.loss_G_edge_A = loss.edge_loss(
                self.outputA, self.realA, self.mask_eye_A,
                method='L1', loss_weight_config=weight_cfg)
            total = total + self.loss_G_edge_A

        # Optional eye-region term.
        if self.loss_config['eyeloss_on']:
            self.loss_G_eye_A = loss.eye_loss(
                self.outputA, self.realA, self.mask_eye_A,
                method='L1', loss_weight_config=weight_cfg)
            total = total + self.loss_G_eye_A

        self.loss_G_A = total
        self.loss_G_A.backward(retain_graph=True)
    def forward(self, content, style, alpha=1.0):
        """Stylize ``content`` with ``style``.

        In eval mode returns only the stylized image ``cs``; in training
        mode returns ``(cs, content_loss, style_loss, identity_loss,
        contextual_loss, total_variation)`` where the optional losses are
        0 when disabled.  ``alpha`` is accepted for interface
        compatibility but unused here.
        """
        style_feats = self.encode_with_intermediate(style)
        cont_feats = self.encode_with_intermediate(content)

        # Feature pyramids over the three deepest encoder levels.
        hidden_cont_feats = self.feature_pyramid(cont_feats[-3:])
        hidden_style_feats = self.feature_pyramid(style_feats[-3:])

        cs, cs_feats = self.pair_inference(cont_feats, style_feats,
                                           hidden_cont_feats,
                                           hidden_style_feats)
        # Inference mode: just the stylized image.
        if not self.training:
            return cs

        # Perceptual (content) loss on the deepest three feature maps.
        loss_c = loss.perceptual_loss(cs_feats[-3:], cont_feats[-3:])

        # AdaIN-statistics style loss over all feature maps.
        loss_s = loss.adain_style_loss(cs_feats, style_feats)

        # Identity loss: content→content and style→style reconstructions.
        if self.use_iden:
            cc, cc_feats = self.pair_inference(cont_feats, cont_feats,
                                               hidden_cont_feats,
                                               hidden_cont_feats, True)
            ss, ss_feats = self.pair_inference(style_feats, style_feats,
                                               hidden_style_feats,
                                               hidden_style_feats, True)
            loss_i = loss.identity_loss(cc, cc_feats, content, cont_feats, ss,
                                        ss_feats, style, style_feats, 50)
        else:
            loss_i = 0

        # Contextual loss, if enabled.
        if self.use_cx:
            loss_cx = loss.contextual_loss(cs_feats, style_feats)
        else:
            loss_cx = 0

        loss_tv = loss.total_variation(cs)
        return (cs, loss_c, loss_s, loss_i, loss_cx, loss_tv)
Example #4
0
def train(cfg):
    # Set device if gpu is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build network
    net = ImageTransformationNet().to(device)

    # Setup optimizer
    optimizer = optim.Adam(net.parameters())

    # Load state if resuming training
    if cfg['resume']:
        checkpoint = torch.load(cfg['resume'])
        net.load_state_dict(checkpoint['net_state_dict'])
        optimizer.load_state_dict(checkpoint['opt_state_dict'])

        # Get starting epoch and batch (expects weight file in form EPOCH_<>_BATCH_<>.pt)
        parts = cfg['resume'].split('_')
        first_epoch = int(checkpoint['epoch'])
        first_batch = int(parts[-1].split('.')[0])

        # Setup dataloader
        train_data = tqdm(build_data_loader(cfg), initial=first_batch)

    else:
        # Setup dataloader
        train_data = tqdm(build_data_loader(cfg))

        # Set first epoch and batch
        first_epoch = 1
        first_batch = 0

    # Fetch style image and style grams
    style_im = load_image(cfg['style_image'], cfg)
    style_grams = get_style_grams(style_im, cfg)

    # Setup log file if specified
    log_dir = Path('logs')
    log_dir.mkdir(parents=True, exist_ok=True)
    if cfg['log_file'] and not cfg['resume']:
        today = datetime.datetime.today().strftime('%m/%d/%Y')
        header = f'Feed-Forward Style Transfer Training Log - {today}'
        with open(cfg['log_file'], 'w+') as file:
            file.write(header + '\n\n')

    # Setup log CSV if specified
    if cfg['csv_log_file'] and not cfg['resume']:
        utils.setup_csv(cfg)

    for epoch in range(first_epoch, cfg['epochs'] + 1):

        # Keep track of per epoch loss
        content_loss = 0
        style_loss = 0
        total_var_loss = 0
        train_loss = 0
        num_batches = 0

        # Setup first batch to start enumerate at proper place
        if epoch == first_epoch:
            start = first_batch
        else:
            start = 0

        for i, batch in enumerate(train_data, start=start):
            batch = batch.to(device)

            # Put batch through network
            batch_styled = net(batch)

            # Get vgg activations for styled and unstyled batch
            features = vgg_activations(batch_styled)
            content_features = vgg_activations(batch)

            # Get loss
            c_loss, s_loss = perceptual_loss(features=features,
                                             content_features=content_features,
                                             style_grams=style_grams,
                                             cfg=cfg)
            tv_loss = total_variation_loss(batch_styled, cfg)
            total_loss = c_loss + s_loss + tv_loss

            # Backpropogate
            total_loss.backward()

            # Do one step of optimization
            optimizer.step()

            # Clear gradients before next batch
            optimizer.zero_grad()

            # Update summary statistics
            with torch.no_grad():
                content_loss += c_loss.item()
                style_loss += s_loss.item()
                total_var_loss += tv_loss.item()
                train_loss += total_loss.item()
                num_batches += 1

            # Update progress bar
            avg_loss = round(train_loss / num_batches, 2)
            avg_c_loss = round(content_loss / num_batches, 2)
            avg_s_loss = round(style_loss / num_batches, 1)
            avg_tv_loss = round(total_var_loss / num_batches, 3)
            train_data.set_description(
                f'C - {avg_c_loss} | S - {avg_s_loss} | TV - {avg_tv_loss} | Total - {avg_loss}'
            )
            train_data.refresh()

            # Create progress image if specified
            if cfg['image_checkpoint'] and ((i + 1) % cfg['image_checkpoint']
                                            == 0):
                save_path = str(
                    Path(
                        cfg['image_checkpoint_dir'],
                        f'EPOCH_{str(epoch).zfill(3)}_BATCH_{str(i+1).zfill(5)}.png'
                    ))
                utils.make_checkpoint_image(cfg, net, save_path)

            # Save weights if specified
            if cfg['save_checkpoint'] and ((i + 1) % cfg['save_checkpoint']
                                           == 0):
                save_path = str(
                    Path(
                        cfg['save_checkpoint_dir'],
                        f'EPOCH_{str(epoch).zfill(3)}_BATCH_{str(i+1).zfill(5)}.pth'
                    ))
                checkpoint = {
                    'epoch': epoch,
                    'net_state_dict': net.state_dict(),
                    'opt_state_dict': optimizer.state_dict(),
                    'loss': avg_loss
                }
                torch.save(checkpoint, save_path)

            # Write progress row to CSV
            if cfg['csv_checkpoint'] and ((i + 1) % cfg['csv_checkpoint']
                                          == 0):
                row = [
                    epoch, i + 1, avg_c_loss, avg_s_loss, avg_tv_loss, avg_loss
                ]
                utils.write_progress_row(cfg, row)

        # Write loss at end of each epoch
        if cfg['log_file']:
            avg_loss = round(train_loss / num_batches, 4)
            line = f'EPOCH {epoch} | Loss - {avg_loss}'
            with open(cfg['log_file'], 'a') as file:
                file.write(line + '\n')

        # Save network if specified
        if cfg['epoch_save_checkpoint'] and (
                epoch % cfg['epoch_save_checkpoint'] == 0):
            save_path = str(
                Path(cfg['save_checkpoint_dir'],
                     f'EPOCH_{str(epoch).zfill(3)}.pth'))
            checkpoint = {
                'epoch': epoch,
                'net_state_dict': net.state_dict(),
                'opt_state_dict': optimizer.state_dict(),
                'loss': round(train_loss / num_batches, 4)
            }
            torch.save(checkpoint, save_path)
Example #5
0
def train(args):
    """Train the motion-transfer model (generator, keypoint detector and
    discriminator) with NNabla, optionally across multiple devices.

    Builds static computation graphs for every enabled loss, sets up
    solvers and monitors, optionally resumes fine-tuning from
    ``args.ft_params``, then runs the epoch/iteration loop, saving
    parameters periodically on device 0.
    """

    # get context

    ctx = get_extension_context(args.context)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    config = read_yaml(args.config)

    if args.info:
        config.monitor_params.info = args.info

    if comm.size == 1:
        comm = None
    else:
        # disable outputs from logger except its rank = 0
        if comm.rank > 0:
            import logging
            logger.setLevel(logging.ERROR)

    test = False
    train_params = config.train_params
    dataset_params = config.dataset_params
    model_params = config.model_params

    loss_flags = get_loss_flags(train_params)

    start_epoch = 0

    # Per-device RNG so each rank gets a different shuffle/augmentation.
    rng = np.random.RandomState(device_id)
    data_iterator = frame_data_iterator(
        root_dir=dataset_params.root_dir,
        frame_shape=dataset_params.frame_shape,
        id_sampling=dataset_params.id_sampling,
        is_train=True,
        random_seed=rng,
        augmentation_params=dataset_params.augmentation_params,
        batch_size=train_params['batch_size'],
        shuffle=True,
        with_memory_cache=False,
        with_file_cache=False)

    if n_devices > 1:
        data_iterator = data_iterator.slice(rng=rng,
                                            num_of_slices=comm.size,
                                            slice_pos=comm.rank)
        # workaround not to use memory cache
        data_iterator._data_source._on_memory = False
        logger.info("Disabled on memory data cache.")

    bs, h, w, c = [train_params.batch_size] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    with nn.parameter_scope("kp_detector"):
        # kp_X = {"value": Variable((bs, 10, 2)), "jacobian": Variable((bs, 10, 2, 2))}

        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=test,
                                    comm=comm)
        persistent_all(kp_source)

        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=test,
                                     comm=comm)
        persistent_all(kp_driving)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=kp_source,
                                              kp_driving=kp_driving,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=test,
                                              comm=comm)
        # generated is a dictionary containing;
        # 'mask': Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
        # 'sparse_deformed': Variable((bs, num_kp + 1, num_channel, h/4, w/4))
        # 'occlusion_map': Variable((bs, 1, h/4, w/4))
        # 'deformed': Variable((bs, c, h, w))
        # 'prediction': Variable((bs, c, h, w)) Only this is fed to discriminator.

    generated["prediction"].persistent = True

    # Multi-scale pyramids of the real and generated frames for the losses.
    pyramide_real = get_image_pyramid(driving, train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_real)

    pyramide_fake = get_image_pyramid(generated['prediction'],
                                      train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_fake)

    total_loss_G = None  # dummy. defined temporarily
    loss_var_dict = {}

    # perceptual loss using VGG19 (always applied)
    if loss_flags.use_perceptual_loss:
        logger.info("Use Perceptual Loss.")
        scales = train_params.scales
        weights = train_params.loss_weights.perceptual
        vgg_param_path = train_params.vgg_param_path
        percep_loss = perceptual_loss(pyramide_real, pyramide_fake, scales,
                                      weights, vgg_param_path)
        percep_loss.persistent = True
        loss_var_dict['perceptual_loss'] = percep_loss
        total_loss_G = percep_loss

    # (LS)GAN loss and feature matching loss
    if loss_flags.use_gan_loss:
        logger.info("Use GAN Loss.")
        with nn.parameter_scope("discriminator"):
            discriminator_maps_generated = multiscale_discriminator(
                pyramide_fake,
                kp=unlink_all(kp_driving),
                **model_params.discriminator_params,
                **model_params.common_params,
                test=test,
                comm=comm)

            discriminator_maps_real = multiscale_discriminator(
                pyramide_real,
                kp=unlink_all(kp_driving),
                **model_params.discriminator_params,
                **model_params.common_params,
                test=test,
                comm=comm)

        for v in discriminator_maps_generated["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_generated["prediction_map_1"].persistent = True

        for v in discriminator_maps_real["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_real["prediction_map_1"].persistent = True

        for i, scale in enumerate(model_params.discriminator_params.scales):
            key = f'prediction_map_{scale}'.replace('.', '-')
            lsgan_loss_weight = train_params.loss_weights.generator_gan
            # LSGAN loss for Generator
            if i == 0:
                gan_loss_gen = lsgan_loss(discriminator_maps_generated[key],
                                          lsgan_loss_weight)
            else:
                gan_loss_gen += lsgan_loss(discriminator_maps_generated[key],
                                           lsgan_loss_weight)
            # LSGAN loss for Discriminator
            if i == 0:
                gan_loss_dis = lsgan_loss(discriminator_maps_real[key],
                                          lsgan_loss_weight,
                                          discriminator_maps_generated[key])
            else:
                gan_loss_dis += lsgan_loss(discriminator_maps_real[key],
                                           lsgan_loss_weight,
                                           discriminator_maps_generated[key])
        gan_loss_dis.persistent = True
        loss_var_dict['gan_loss_dis'] = gan_loss_dis
        total_loss_D = gan_loss_dis
        total_loss_D.persistent = True

        gan_loss_gen.persistent = True
        loss_var_dict['gan_loss_gen'] = gan_loss_gen
        total_loss_G += gan_loss_gen

        if loss_flags.use_feature_matching_loss:
            logger.info("Use Feature Matching Loss.")
            fm_weights = train_params.loss_weights.feature_matching
            fm_loss = feature_matching_loss(discriminator_maps_real,
                                            discriminator_maps_generated,
                                            model_params, fm_weights)
            fm_loss.persistent = True
            loss_var_dict['feature_matching_loss'] = fm_loss
            total_loss_G += fm_loss

    # transform loss
    if loss_flags.use_equivariance_value_loss or loss_flags.use_equivariance_jacobian_loss:
        transform = Transform(bs, **config.train_params.transform_params)
        transformed_frame = transform.transform_frame(driving)

        with nn.parameter_scope("kp_detector"):
            transformed_kp = detect_keypoint(transformed_frame,
                                             **model_params.kp_detector_params,
                                             **model_params.common_params,
                                             test=test,
                                             comm=comm)
        persistent_all(transformed_kp)

        # Value loss part
        if loss_flags.use_equivariance_value_loss:
            logger.info("Use Equivariance Value Loss.")
            warped_kp_value = transform.warp_coordinates(
                transformed_kp['value'])
            eq_value_weight = train_params.loss_weights.equivariance_value

            eq_value_loss = equivariance_value_loss(kp_driving['value'],
                                                    warped_kp_value,
                                                    eq_value_weight)
            eq_value_loss.persistent = True
            loss_var_dict['equivariance_value_loss'] = eq_value_loss
            total_loss_G += eq_value_loss

        # jacobian loss part
        if loss_flags.use_equivariance_jacobian_loss:
            logger.info("Use Equivariance Jacobian Loss.")
            arithmetic_jacobian = transform.jacobian(transformed_kp['value'])
            eq_jac_weight = train_params.loss_weights.equivariance_jacobian
            eq_jac_loss = equivariance_jacobian_loss(
                kp_driving['jacobian'], arithmetic_jacobian,
                transformed_kp['jacobian'], eq_jac_weight)
            eq_jac_loss.persistent = True
            loss_var_dict['equivariance_jacobian_loss'] = eq_jac_loss
            total_loss_G += eq_jac_loss

    assert total_loss_G is not None
    total_loss_G.persistent = True
    loss_var_dict['total_loss_gen'] = total_loss_G

    # -------------------- Create Monitors --------------------
    monitors_gen, monitors_dis, monitor_time, monitor_vis, log_dir = get_monitors(
        config, loss_flags, loss_var_dict)

    if device_id == 0:
        # Dump training info .yaml
        _ = shutil.copy(args.config, log_dir)  # copy the config yaml
        training_info_yaml = os.path.join(log_dir, "training_info.yaml")
        os.rename(os.path.join(log_dir, os.path.basename(args.config)),
                  training_info_yaml)
        # then add additional information
        with open(training_info_yaml, "a", encoding="utf-8") as f:
            f.write(f"\nlog_dir: {log_dir}\nsaved_parameter: None")

    # -------------------- Solver Setup --------------------
    solvers = setup_solvers(train_params)
    solver_generator = solvers["generator"]
    solver_discriminator = solvers["discriminator"]
    solver_kp_detector = solvers["kp_detector"]

    # max epochs
    num_epochs = train_params['num_epochs']

    # iteration per epoch
    num_iter_per_epoch = data_iterator.size // bs
    # will be increased by num_repeat
    # BUGFIX: the original used `or`, which evaluated
    # train_params['num_repeats'] (raising KeyError) whenever the key was
    # absent. With `and` we multiply only when the key exists and != 1.
    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        num_iter_per_epoch *= config.train_params.num_repeats

    # modify learning rate if current epoch exceeds the number defined in
    lr_decay_at_epochs = train_params['epoch_milestones']  # ex. [60, 90]
    gamma = 0.1  # decay rate

    # -------------------- For finetuning ---------------------
    if args.ft_params:
        assert os.path.isfile(args.ft_params)
        logger.info(f"load {args.ft_params} for finetuning.")
        nn.load_parameters(args.ft_params)
        # Parameter file is expected to be named ...epoch_<N>.h5.
        start_epoch = int(
            os.path.splitext(os.path.basename(
                args.ft_params))[0].split("epoch_")[1])

        # set solver's state
        for name, solver in solvers.items():
            saved_states = os.path.join(
                os.path.dirname(args.ft_params),
                f"state_{name}_at_epoch_{start_epoch}.h5")
            solver.load_states(saved_states)

        start_epoch += 1
        logger.info(f"Resuming from epoch {start_epoch}.")

    logger.info(
        f"Start training. Total epoch: {num_epochs - start_epoch}, {num_iter_per_epoch * n_devices} iter/epoch."
    )

    for e in range(start_epoch, num_epochs):
        logger.info(f"Epoch: {e} / {num_epochs}.")
        data_iterator._reset()  # rewind the iterator at the beginning

        # learning rate scheduler
        if e in lr_decay_at_epochs:
            logger.info("Learning rate decayed.")
            learning_rate_decay(solvers, gamma=gamma)

        for i in range(num_iter_per_epoch):
            _driving, _source = data_iterator.next()
            source.d = _source
            driving.d = _driving

            # update generator and keypoint detector
            total_loss_G.forward()

            if device_id == 0:
                monitors_gen.add((e * num_iter_per_epoch + i) * n_devices)

            solver_generator.zero_grad()
            solver_kp_detector.zero_grad()

            # All-reduce gradients across devices when running distributed.
            callback = None
            if n_devices > 1:
                params = [x.grad for x in solver_generator.get_parameters().values()] + \
                         [x.grad for x in solver_kp_detector.get_parameters().values()]
                callback = comm.all_reduce_callback(params, 2 << 20)
            total_loss_G.backward(clear_buffer=True,
                                  communicator_callbacks=callback)

            solver_generator.update()
            solver_kp_detector.update()

            if loss_flags.use_gan_loss:
                # update discriminator

                total_loss_D.forward(clear_no_need_grad=True)
                if device_id == 0:
                    monitors_dis.add((e * num_iter_per_epoch + i) * n_devices)

                solver_discriminator.zero_grad()

                callback = None
                if n_devices > 1:
                    params = [
                        x.grad for x in
                        solver_discriminator.get_parameters().values()
                    ]
                    callback = comm.all_reduce_callback(params, 2 << 20)
                total_loss_D.backward(clear_buffer=True,
                                      communicator_callbacks=callback)

                solver_discriminator.update()

            if device_id == 0:
                monitor_time.add((e * num_iter_per_epoch + i) * n_devices)

            # Periodically dump a visual comparison (source/driving/prediction).
            if device_id == 0 and (
                (e * num_iter_per_epoch + i) *
                    n_devices) % config.monitor_params.visualize_freq == 0:
                images_to_visualize = [
                    source.d, driving.d, generated["prediction"].d
                ]
                visuals = combine_images(images_to_visualize)
                monitor_vis.add((e * num_iter_per_epoch + i) * n_devices,
                                visuals)

        if device_id == 0:
            if e % train_params.checkpoint_freq == 0 or e == num_epochs - 1:
                save_parameters(e, log_dir, solvers)

    return
Example #6
0
def main():
    """Entry point: build model, optimizer and data loaders, then run the
    train/validate loop for the configured number of epochs.

    Shares state (params, writers, loggers, best IoU, iteration count)
    through module-level globals because ``train``/``validate`` defined
    elsewhere in this file read them.
    """
    global params, best_iou, num_iter, tb_writer, logger, logger_results
    best_iou = 0
    params = Params()
    params.save_params('{:s}/params.txt'.format(params.paths['save_dir']))
    tb_writer = SummaryWriter('{:s}/tb_logs'.format(params.paths['save_dir']))

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        str(x) for x in params.train['gpu'])

    # set up logger
    logger, logger_results = setup_logging(params)

    # ----- create model ----- #
    model_name = params.model['name']
    if model_name == 'ResUNet34':
        model = ResUNet34(params.model['out_c'],
                          fixed_feature=params.model['fix_params'])
    elif params.model['name'] == 'UNet':
        model = UNet(3, params.model['out_c'])
    else:
        raise NotImplementedError()

    logger.info('Model: {:s}'.format(model_name))
    # if not params.train['checkpoint']:
    #     logger.info(model)
    model = nn.DataParallel(model)
    model = model.cuda()
    # VGG feature extractor is global so the loss functions can reach it.
    global vgg_model
    logger.info('=> Using VGG16 for perceptual loss...')
    vgg_model = vgg16_feat()
    vgg_model = nn.DataParallel(vgg_model).cuda()
    cudnn.benchmark = True

    # ----- define optimizer ----- #
    optimizer = torch.optim.Adam(model.parameters(),
                                 params.train['lr'],
                                 betas=(0.9, 0.99),
                                 weight_decay=params.train['weight_decay'])

    # ----- get pixel weights and define criterion ----- #
    # reduction='none' keeps per-pixel losses so weight maps can scale them.
    if not params.train['weight_map']:
        criterion = torch.nn.NLLLoss().cuda()
    else:
        logger.info('=> Using weight maps...')
        criterion = torch.nn.NLLLoss(reduction='none').cuda()

    if params.train['beta'] > 0:
        logger.info('=> Using perceptual loss...')
        global criterion_perceptual
        criterion_perceptual = perceptual_loss()

    data_transforms = {
        'train': get_transforms(params.transform['train']),
        'val': get_transforms(params.transform['val'])
    }

    # ----- load data ----- #
    dsets = {}
    for x in ['train', 'val']:
        img_dir = '{:s}/{:s}'.format(params.paths['img_dir'], x)
        target_dir = '{:s}/{:s}'.format(params.paths['label_dir'], x)
        if params.train['weight_map']:
            weight_map_dir = '{:s}/{:s}'.format(params.paths['weight_map_dir'],
                                                x)
            dir_list = [img_dir, weight_map_dir, target_dir]
            postfix = ['weight.png', 'label_with_contours.png']
            num_channels = [3, 1, 3]
        else:
            dir_list = [img_dir, target_dir]
            postfix = ['label_with_contours.png']
            num_channels = [3, 3]
        dsets[x] = DataFolder(dir_list, postfix, num_channels,
                              data_transforms[x])
    train_loader = DataLoader(dsets['train'],
                              batch_size=params.train['batch_size'],
                              shuffle=True,
                              num_workers=params.train['workers'])
    val_loader = DataLoader(dsets['val'],
                            batch_size=params.train['val_batch_size'],
                            shuffle=False,
                            num_workers=params.train['workers'])

    # ----- optionally load from a checkpoint for validation or resuming training ----- #
    if params.train['checkpoint']:
        if os.path.isfile(params.train['checkpoint']):
            logger.info("=> loading checkpoint '{}'".format(
                params.train['checkpoint']))
            checkpoint = torch.load(params.train['checkpoint'])
            params.train['start_epoch'] = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                params.train['checkpoint'], checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(
                params.train['checkpoint']))

    # ----- training and validation ----- #
    # NOTE(review): num_iter is only assigned here; presumably read via the
    # global by train() — confirm it is still needed.
    num_iter = params.train['num_epochs'] * len(train_loader)

    # print training parameters
    logger.info("=> Initial learning rate: {:g}".format(params.train['lr']))
    logger.info("=> Batch size: {:d}".format(params.train['batch_size']))
    # logger.info("=> Number of training iterations: {:d}".format(num_iter))
    logger.info("=> Training epochs: {:d}".format(params.train['num_epochs']))
    logger.info("=> beta: {:.1f}".format(params.train['beta']))

    for epoch in range(params.train['start_epoch'],
                       params.train['num_epochs']):
        # train for one epoch or len(train_loader) iterations
        logger.info('Epoch: [{:d}/{:d}]'.format(epoch + 1,
                                                params.train['num_epochs']))
        train_results = train(train_loader, model, optimizer, criterion, epoch)
        train_loss, train_loss_ce, train_loss_var, train_iou_nuclei, train_iou = train_results

        # evaluate on validation set
        # NOTE(review): val_loss_ce and val_loss_var are unpacked but never
        # logged below — confirm whether they should appear in the results.
        with torch.no_grad():
            val_results = validate(val_loader, model, criterion)
            val_loss, val_loss_ce, val_loss_var, val_iou_nuclei, val_iou = val_results

        # check if it is the best accuracy
        combined_iou = (val_iou_nuclei + val_iou) / 2
        is_best = combined_iou > best_iou
        best_iou = max(combined_iou, best_iou)

        cp_flag = (epoch + 1) % params.train['checkpoint_freq'] == 0

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_iou': best_iou,
                'optimizer': optimizer.state_dict(),
            }, epoch, is_best, params.paths['save_dir'], cp_flag)

        # save the training results to txt files
        logger_results.info(
            '{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'
            .format(epoch + 1, train_loss, train_loss_ce, train_loss_var,
                    train_iou_nuclei, train_iou, val_loss, val_iou_nuclei,
                    val_iou))
        # tensorboard logs
        tb_writer.add_scalars(
            'epoch_losses', {
                'train_loss': train_loss,
                'train_loss_ce': train_loss_ce,
                'train_loss_var': train_loss_var,
                'val_loss': val_loss
            }, epoch)
        tb_writer.add_scalars(
            'epoch_accuracies', {
                'train_iou_nuclei': train_iou_nuclei,
                'train_iou': train_iou,
                'val_iou_nuclei': val_iou_nuclei,
                'val_iou': val_iou
            }, epoch)
    tb_writer.close()