Example #1
def main():
    parser = argparse.ArgumentParser(description="Baseline Experiment Eval")
    parser.add_argument(
        "--config-file",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

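    # Load the YAML config, name the experiment after the config file
    # (the [:-5] slice drops the 5-character extension), and apply any
    # command-line overrides before freezing the config.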
    cfg.merge_from_file(args.config_file)
    cfg.EXPERIMENT_NAME = args.config_file.split('/')[-1][:-5]
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # Seeding
    random.seed(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # This can slow down training

    # load the data
    test_loader = torch.utils.data.DataLoader(get_dataset(cfg, 'test'),
                                              batch_size=cfg.TEST.BATCH_SIZE,
                                              shuffle=False,
                                              pin_memory=True)

    task1, task2 = get_tasks(cfg)
    model = get_model(cfg, task1, task2)

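    # Build the checkpoint path from SAVE_DIR/EXPERIMENT_NAME with a
    # zero-padded checkpoint id, then restore the weights for evaluation.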
    ckpt_path = os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME,
                             'ckpt-%s.pth' % str(cfg.TEST.CKPT_ID).zfill(5))
    print("Evaluating Checkpoint at %s" % ckpt_path)
    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    if cfg.CUDA:
        model = model.cuda()

    task1_metric, task2_metric = evaluate(test_loader, model, task1, task2)
    for k, v in task1_metric.items():
        print('{}: {:.3f}'.format(k, v))
    for k, v in task2_metric.items():
        print('{}: {:.3f}'.format(k, v))
Example #2
def main():
    root = '/home/yel/yel/data/Aerialgoaf/detail/'
    # root = '/home/yel/yel/data/DeepCrack-master/dataset/Deepcrack/'
    img_dir = root + 'train'
    label_dir = root + 'trainannot'
    val_dir = root + 'val'
    vallabel_dir = root + 'valannot'
    train_ds = get_dataset(img_dir, label_dir, batch_size=5)
    val_ds = get_dataset(val_dir, vallabel_dir, batch_size=5)
    model = MSI_FCN()

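    # Learning rate decays exponentially from 2e-4 by a factor of 0.1 every 10000 steps.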
    lr = tf.keras.optimizers.schedules.ExponentialDecay(2e-4, 10000, 0.1)
    optimizer = tf.keras.optimizers.Adam(lr, beta_1=0.5)

    fit(train_ds=train_ds,
        val_ds=val_ds,
        model=model,
        optimizer=optimizer,
        loss_func=WSCE,
        work_dir='../work_dir/msi_fcn_2',
        epochs=100,
        fine_tune=True)
Example #3
import tensorflow as tf
import time
import os
from model.msi_fcn import MSI_FCN
from core.data import get_dataset
from core.loss import WSCE
from core.metrics import show_metrics
import datetime

root = '/home/yel/yel/data/Aerialgoaf/detail/'
img_dir = root + 'train'
label_dir = root + 'trainannot'
train_ds = get_dataset(img_dir, label_dir, batch_size=3)
model = MSI_FCN()
optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

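# Checkpoint the optimizer and model together; keep only the 5 most recent checkpoints.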
checkpoint_dir = './training_checkpoints'
# checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)

log_dir = './logs/'
summary_writer = tf.summary.create_file_writer(
    log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))


# @tf.function
def train_step(model, input, label, loss_object, optimizer, show_metrics, summary_writer, step):
    with tf.GradientTape() as t:
        output = model(input, training=True)
        loss = loss_object(output, label)
    # Back-propagate through the tape and apply the gradients
    # (standard TF2 custom training step).
    gradients = t.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
Example #4
def main():
    parser = argparse.ArgumentParser(description="PyTorch MTLNAS Eval")
    parser.add_argument(
        "--config-file",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--port", type=int, default=29502)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    # Preparing for DDP training
    logging = args.local_rank == 0
    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

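    # Single-node setup: all ranks rendezvous on localhost at the chosen port,
    # then join the NCCL process group.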
    if distributed:
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(args.port)
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.EXPERIMENT_NAME = args.config_file.split('/')[-1][:-5]
    cfg.merge_from_list(args.opts)
    # Adjust batch size for distributed training
    assert cfg.TRAIN.BATCH_SIZE % num_gpus == 0
    cfg.TRAIN.BATCH_SIZE = int(cfg.TRAIN.BATCH_SIZE // num_gpus)
    assert cfg.TEST.BATCH_SIZE % num_gpus == 0
    cfg.TEST.BATCH_SIZE = int(cfg.TEST.BATCH_SIZE // num_gpus)
    cfg.freeze()

    # Seeding
    random.seed(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # This can slow down training

    if not os.path.exists(os.path.join(cfg.SAVE_DIR,
                                       cfg.EXPERIMENT_NAME)) and logging:
        os.makedirs(os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME))

    # load the data
    test_data = get_dataset(cfg, 'test')

    if distributed:
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            test_data)
    else:
        test_sampler = None

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=cfg.TEST.BATCH_SIZE,
                                              shuffle=False,
                                              sampler=test_sampler)

    task1, task2 = get_tasks(cfg)
    model = get_model(cfg, task1, task2)

    if cfg.CUDA:
        model = model.cuda()

    ckpt_path = os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME,
                             'ckpt-%s.pth' % str(cfg.TEST.CKPT_ID).zfill(5))
    print("Evaluating Checkpoint at %s" % ckpt_path)
    ckpt = torch.load(ckpt_path)
    # compatibility with ddp saved checkpoint when evaluating without ddp
    pretrain_dict = {
        k.replace('module.', ''): v
        for k, v in ckpt['model_state_dict'].items()
    }
    model_dict = model.state_dict()
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)

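    # For multi-GPU evaluation, convert BatchNorm to SyncBatchNorm and wrap the
    # model with the data-parallel helper pinned to this process's GPU.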
    if distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = MyDataParallel(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank,
                               find_unused_parameters=True)

    model.eval()

    task1_metric, task2_metric = evaluate(test_loader, model, task1, task2,
                                          distributed, args.local_rank)
    if logging:
        for k, v in task1_metric.items():
            print('{}: {:.9f}'.format(k, v))
        for k, v in task2_metric.items():
            print('{}: {:.9f}'.format(k, v))
Example #5
def main():
    parser = argparse.ArgumentParser(description="PyTorch MTLNAS Training")
    parser.add_argument(
        "--config-file",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--port", type=int, default=29501)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    # Preparing for DDP training
    logging = args.local_rank == 0
    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(args.port)
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.EXPERIMENT_NAME = args.config_file.split('/')[-1][:-5]
    cfg.merge_from_list(args.opts)
    # Adjust batch size for distributed training
    assert cfg.TRAIN.BATCH_SIZE % num_gpus == 0
    cfg.TRAIN.BATCH_SIZE = int(cfg.TRAIN.BATCH_SIZE // num_gpus)
    assert cfg.TEST.BATCH_SIZE % num_gpus == 0
    cfg.TEST.BATCH_SIZE = int(cfg.TEST.BATCH_SIZE // num_gpus)
    cfg.freeze()

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d~%H:%M:%S")
    experiment_log_dir = os.path.join(cfg.LOG_DIR, cfg.EXPERIMENT_NAME,
                                      timestamp)
    if not os.path.exists(experiment_log_dir) and logging:
        os.makedirs(experiment_log_dir)
        writer = SummaryWriter(logdir=experiment_log_dir)
    printf = get_print(experiment_log_dir)
    printf("Training with Config: ")
    printf(cfg)

    # Seeding
    os.environ['PYTHONHASHSEED'] = str(cfg.SEED)
    random.seed(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # This can slow down training

    if not os.path.exists(os.path.join(cfg.SAVE_DIR,
                                       cfg.EXPERIMENT_NAME)) and logging:
        os.makedirs(os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME))

    # load the data
    train_full_data = get_dataset(cfg, 'train')

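    # Split the training set by ARCH.TRAIN_SPLIT: the first portion trains the
    # network weights, the remainder is used as the architecture-search split.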
    num_train = len(train_full_data)
    indices = list(range(num_train))
    split = int(np.floor(cfg.ARCH.TRAIN_SPLIT * num_train))

    # load the validation data used for checkpoint evaluation
    if cfg.TRAIN.EVAL_CKPT:
        test_data = get_dataset(cfg, 'val')

        if distributed:
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_data)
        else:
            test_sampler = None

        test_loader = torch.utils.data.DataLoader(
            test_data,
            batch_size=cfg.TEST.BATCH_SIZE,
            shuffle=False,
            sampler=test_sampler)

    task1, task2 = get_tasks(cfg)
    model = get_model(cfg, task1, task2)

    if cfg.CUDA:
        model = model.cuda()

    if distributed:
        # Important: Double check if BN is working as expected
        if cfg.TRAIN.APEX:
            printf("using apex synced BN")
            model = apex.parallel.convert_syncbn_model(model)
        else:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = MyDataParallel(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank,
                               find_unused_parameters=True)

    # hacky way to pick params
    nddr_params = []
    fc8_weights = []
    fc8_bias = []
    base_params = []
    for k, v in model.named_net_parameters():
        if 'paths' in k:
            nddr_params.append(v)
        elif model.net1.fc_id in k:
            if 'weight' in k:
                fc8_weights.append(v)
            else:
                assert 'bias' in k
                fc8_bias.append(v)
        else:
            assert 'alpha' not in k
            base_params.append(v)
    assert len(nddr_params) > 0 and len(fc8_weights) > 0 and len(fc8_bias) > 0

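    # Parameter groups with scaled learning rates for the fc8 head and the
    # NDDR/path parameters; base parameters use the default learning rate.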
    parameter_dict = [{
        'params': base_params
    }, {
        'params': fc8_weights,
        'lr': cfg.TRAIN.LR * cfg.TRAIN.FC8_WEIGHT_FACTOR
    }, {
        'params': fc8_bias,
        'lr': cfg.TRAIN.LR * cfg.TRAIN.FC8_BIAS_FACTOR
    }, {
        'params': nddr_params,
        'lr': cfg.TRAIN.LR * cfg.TRAIN.NDDR_FACTOR
    }]
    optimizer = optim.SGD(parameter_dict,
                          lr=cfg.TRAIN.LR,
                          momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)

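    # A separate optimizer (SGD or Adam) handles the architecture (alpha) parameters.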
    if cfg.ARCH.OPTIMIZER == 'sgd':
        arch_optimizer = torch.optim.SGD(
            model.arch_parameters(),
            lr=cfg.ARCH.LR,
            momentum=cfg.TRAIN.MOMENTUM,  # TODO: separate this param
            weight_decay=cfg.ARCH.WEIGHT_DECAY)
    else:
        arch_optimizer = torch.optim.Adam(model.arch_parameters(),
                                          lr=cfg.ARCH.LR,
                                          betas=(0.5, 0.999),
                                          weight_decay=cfg.ARCH.WEIGHT_DECAY)

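    # Matching LR schedules for both optimizers: Poly (with optional linear
    # warmup), Cosine annealing, or Step decay at 60% and 90% of training.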
    if cfg.TRAIN.SCHEDULE == 'Poly':
        if cfg.TRAIN.WARMUP > 0.:
            scheduler = optim.lr_scheduler.LambdaLR(
                optimizer,
                lambda step: min(1.,
                                 float(step) / cfg.TRAIN.WARMUP) *
                (1 - float(step) / cfg.TRAIN.STEPS)**cfg.TRAIN.POWER,
                last_epoch=-1)
            arch_scheduler = optim.lr_scheduler.LambdaLR(
                arch_optimizer,
                lambda step: min(1.,
                                 float(step) / cfg.TRAIN.WARMUP) *
                (1 - float(step) / cfg.TRAIN.STEPS)**cfg.TRAIN.POWER,
                last_epoch=-1)
        else:
            scheduler = optim.lr_scheduler.LambdaLR(
                optimizer,
                lambda step:
                (1 - float(step) / cfg.TRAIN.STEPS)**cfg.TRAIN.POWER,
                last_epoch=-1)
            arch_scheduler = optim.lr_scheduler.LambdaLR(
                arch_optimizer,
                lambda step:
                (1 - float(step) / cfg.TRAIN.STEPS)**cfg.TRAIN.POWER,
                last_epoch=-1)
    elif cfg.TRAIN.SCHEDULE == 'Cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, cfg.TRAIN.STEPS)
        arch_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            arch_optimizer, cfg.TRAIN.STEPS)
    elif cfg.TRAIN.SCHEDULE == 'Step':
        milestones = (np.array([0.6, 0.9]) * cfg.TRAIN.STEPS).astype('int')
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones,
                                                   gamma=0.1)
        arch_scheduler = optim.lr_scheduler.MultiStepLR(arch_optimizer,
                                                        milestones,
                                                        gamma=0.1)
    else:
        raise NotImplementedError

    if cfg.TRAIN.APEX:
        model, [arch_optimizer,
                optimizer] = amp.initialize(model, [arch_optimizer, optimizer],
                                            opt_level="O1",
                                            num_losses=2)

    model.train()
    steps = 0
    while steps < cfg.TRAIN.STEPS:
        # Initialize train/val dataloader below this shuffle operation
        # to ensure both arch and weights gets to see all the data,
        # but not at the same time during mixed data training
        if cfg.ARCH.MIXED_DATA:
            np.random.shuffle(indices)

        train_data = torch.utils.data.Subset(train_full_data, indices[:split])
        val_data = torch.utils.data.Subset(train_full_data,
                                           indices[split:num_train])

        if distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_data)
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            train_sampler = None
            val_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_data,
            batch_size=cfg.TRAIN.BATCH_SIZE,
            pin_memory=True,
            sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=cfg.TRAIN.BATCH_SIZE,
            pin_memory=True,
            sampler=val_sampler)

        val_iter = iter(val_loader)

        if distributed:
            train_sampler.set_epoch(steps)  # steps is used to seed RNG
            val_sampler.set_epoch(steps)

        for batch_idx, (image, label_1, label_2) in enumerate(train_loader):
            if cfg.CUDA:
                image, label_1, label_2 = image.cuda(), label_1.cuda(), label_2.cuda()

            # get a random minibatch from the search queue without replacement
            val_batch = next(val_iter, None)
            if val_batch is None:  # val_iter has reached its end
                if distributed:  # val_sampler is None without DDP
                    val_sampler.set_epoch(steps)
                val_iter = iter(val_loader)
                val_batch = next(val_iter)
            image_search, label_1_search, label_2_search = val_batch
            image_search = image_search.cuda()
            label_1_search, label_2_search = label_1_search.cuda(), label_2_search.cuda()

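            # Alternate the two updates: architecture parameters first, on the
            # search batch, then the network weights on the training batch.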
            # setting flag for training arch parameters
            model.arch_train()
            assert model.arch_training
            arch_optimizer.zero_grad()
            arch_result = model.loss(image_search,
                                     (label_1_search, label_2_search))
            arch_loss = arch_result.loss

            # Mixed Precision
            if cfg.TRAIN.APEX:
                with amp.scale_loss(arch_loss, arch_optimizer,
                                    loss_id=0) as scaled_loss:
                    scaled_loss.backward()
            else:
                arch_loss.backward()

            arch_optimizer.step()
            model.arch_eval()

            assert not model.arch_training
            optimizer.zero_grad()

            result = model.loss(image, (label_1, label_2))

            out1, out2 = result.out1, result.out2
            loss1 = result.loss1
            loss2 = result.loss2

            loss = result.loss

            # Mixed Precision
            if cfg.TRAIN.APEX:
                with amp.scale_loss(loss, optimizer, loss_id=1) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()
            if cfg.ARCH.SEARCHSPACE == 'GeneralizedMTLNAS':
                model.step()  # update model temperature
            scheduler.step()
            if cfg.ARCH.OPTIMIZER == 'sgd':
                arch_scheduler.step()

            # Print out the loss periodically.
            if steps % cfg.TRAIN.LOG_INTERVAL == 0 and logging:
                printf(
                    'Train Step: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLoss1: {:.6f}\tLoss2: {:.6f}'
                    .format(steps, batch_idx * len(image),
                            len(train_loader.dataset),
                            100. * batch_idx / len(train_loader),
                            loss.data.item(), loss1.data.item(),
                            loss2.data.item()))

                # Log to tensorboard
                writer.add_scalar('lr', scheduler.get_lr()[0], steps)
                writer.add_scalar('arch_lr', arch_scheduler.get_lr()[0], steps)
                writer.add_scalar('loss/overall', loss.data.item(), steps)
                writer.add_image(
                    'image', process_image(image[0],
                                           train_full_data.image_mean), steps)
                task1.log_visualize(out1, label_1, loss1, writer, steps)
                task2.log_visualize(out2, label_2, loss2, writer, steps)

                if cfg.ARCH.ENTROPY_REGULARIZATION:
                    writer.add_scalar('loss/entropy_weight',
                                      arch_result.entropy_weight, steps)
                    writer.add_scalar('loss/entropy_loss',
                                      arch_result.entropy_loss.data.item(),
                                      steps)

                if cfg.ARCH.L1_REGULARIZATION:
                    writer.add_scalar('loss/l1_weight', arch_result.l1_weight,
                                      steps)
                    writer.add_scalar('loss/l1_loss',
                                      arch_result.l1_loss.data.item(), steps)

                if cfg.ARCH.SEARCHSPACE == 'GeneralizedMTLNAS':
                    writer.add_scalar('temperature', model.get_temperature(),
                                      steps)
                    alpha1 = torch.sigmoid(
                        model.net1_alphas).detach().cpu().numpy()
                    alpha2 = torch.sigmoid(
                        model.net2_alphas).detach().cpu().numpy()
                    alpha1_path = os.path.join(experiment_log_dir, 'alpha1')
                    if not os.path.isdir(alpha1_path):
                        os.makedirs(alpha1_path)
                    alpha2_path = os.path.join(experiment_log_dir, 'alpha2')
                    if not os.path.isdir(alpha2_path):
                        os.makedirs(alpha2_path)
                    heatmap1 = save_heatmap(
                        alpha1,
                        os.path.join(alpha1_path,
                                     "%s_alpha1.png" % str(steps).zfill(5)))
                    heatmap2 = save_heatmap(
                        alpha2,
                        os.path.join(alpha2_path,
                                     "%s_alpha2.png" % str(steps).zfill(5)))
                    writer.add_image('alpha/net1', heatmap1, steps)
                    writer.add_image('alpha/net2', heatmap2, steps)
                    network_path = os.path.join(experiment_log_dir, 'network')
                    if not os.path.isdir(network_path):
                        os.makedirs(network_path)
                    connectivity_plot = save_connectivity(
                        alpha1, alpha2, model.net1_connectivity_matrix,
                        model.net2_connectivity_matrix,
                        os.path.join(network_path,
                                     "%s_network.png" % str(steps).zfill(5)))
                    writer.add_image('network', connectivity_plot, steps)

            if steps % cfg.TRAIN.EVAL_INTERVAL == 0:
                if distributed:
                    state_dict = model.module.state_dict()
                else:
                    state_dict = model.state_dict()

                checkpoint = {
                    'cfg': cfg,
                    'step': steps,
                    'model_state_dict': state_dict,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': loss,
                    'loss1': loss1,
                    'loss2': loss2,
                    'task1_metric': None,
                    'task2_metric': None,
                }

                if cfg.TRAIN.EVAL_CKPT:
                    model.eval()
                    torch.cuda.empty_cache()  # TODO check if it helps
                    task1_metric, task2_metric = evaluate(
                        test_loader, model, task1, task2, distributed,
                        args.local_rank)

                    if logging:
                        for k, v in task1_metric.items():
                            writer.add_scalar('eval/{}'.format(k), v, steps)
                        for k, v in task2_metric.items():
                            writer.add_scalar('eval/{}'.format(k), v, steps)
                        for k, v in task1_metric.items():
                            printf('{}: {:.3f}'.format(k, v))
                        for k, v in task2_metric.items():
                            printf('{}: {:.3f}'.format(k, v))

                    checkpoint['task1_metric'] = task1_metric
                    checkpoint['task2_metric'] = task2_metric
                    model.train()
                    torch.cuda.empty_cache()  # TODO check if it helps

                if logging and steps % cfg.TRAIN.SAVE_INTERVAL == 0:
                    torch.save(
                        checkpoint,
                        os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME,
                                     'ckpt-%s.pth' % str(steps).zfill(5)))

            if steps >= cfg.TRAIN.STEPS:
                break
            steps += 1  # train for one extra iteration to allow time for tensorboard logging
Example #6
def main():
    parser = argparse.ArgumentParser(description="Baseline Experiment Training")
    parser.add_argument(
        "--config-file",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)
    cfg.EXPERIMENT_NAME = args.config_file.split('/')[-1][:-5]
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d~%H:%M:%S")
    experiment_log_dir = os.path.join(cfg.LOG_DIR, cfg.EXPERIMENT_NAME, timestamp)
    if not os.path.exists(experiment_log_dir):
        os.makedirs(experiment_log_dir)
    writer = SummaryWriter(logdir=experiment_log_dir)
    printf = get_print(experiment_log_dir)
    printf("Training with Config: ")
    printf(cfg)

    # Seeding
    random.seed(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # This can slow down training

    if not os.path.exists(os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME)):
        os.makedirs(os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME))

    # load the data
    train_data = get_dataset(cfg, 'train')
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True, pin_memory=True)

    # load the validation data used for checkpoint evaluation
    if cfg.TRAIN.EVAL_CKPT:
        test_loader = torch.utils.data.DataLoader(
            get_dataset(cfg, 'val'),
            batch_size=cfg.TEST.BATCH_SIZE, shuffle=False, pin_memory=True)

    task1, task2 = get_tasks(cfg)
    model = get_model(cfg, task1, task2)
    
    if cfg.CUDA:
        model = model.cuda()

    # hacky way to pick params
    nddr_params = []
    fc8_weights = []
    fc8_bias = []
    base_params = []
    for k, v in model.named_parameters():
        if 'nddrs' in k:
            nddr_params.append(v)
        elif model.net1.fc_id in k:
            if 'weight' in k:
                fc8_weights.append(v)
            else:
                assert 'bias' in k
                fc8_bias.append(v)
        else:
            base_params.append(v)
    
    if not cfg.MODEL.SINGLETASK and not cfg.MODEL.SHAREDFEATURE:
        assert len(nddr_params) > 0 and len(fc8_weights) > 0 and len(fc8_bias) > 0

    parameter_dict = [
        {'params': fc8_weights, 'lr': cfg.TRAIN.LR * cfg.TRAIN.FC8_WEIGHT_FACTOR},
        {'params': fc8_bias, 'lr': cfg.TRAIN.LR * cfg.TRAIN.FC8_BIAS_FACTOR},
        {'params': nddr_params, 'lr': cfg.TRAIN.LR * cfg.TRAIN.NDDR_FACTOR}
    ]
    
    if not cfg.TRAIN.FREEZE_BASE:
        parameter_dict.append({'params': base_params})
    else:
        printf("Frozen net weights")
        
    optimizer = optim.SGD(parameter_dict, lr=cfg.TRAIN.LR, momentum=cfg.TRAIN.MOMENTUM,
                          weight_decay=cfg.TRAIN.WEIGHT_DECAY)

    if cfg.TRAIN.SCHEDULE == 'Poly':
        if cfg.TRAIN.WARMUP > 0.:
            scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                    lambda step: min(1., float(step) / cfg.TRAIN.WARMUP) * (1 - float(step) / cfg.TRAIN.STEPS) ** cfg.TRAIN.POWER,
                                                    last_epoch=-1)
        else:
            scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                    lambda step: (1 - float(step) / cfg.TRAIN.STEPS) ** cfg.TRAIN.POWER,
                                                    last_epoch=-1)
    elif cfg.TRAIN.SCHEDULE == 'Cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg.TRAIN.STEPS)
    else:
        raise NotImplementedError
        
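    # Optional NVIDIA Apex automatic mixed precision (opt level O1).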
    if cfg.TRAIN.APEX:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    model.train()
    steps = 0
    while steps < cfg.TRAIN.STEPS:
        for batch_idx, (image, label_1, label_2) in enumerate(train_loader):
            if cfg.CUDA:
                image, label_1, label_2 = image.cuda(), label_1.cuda(), label_2.cuda()
            optimizer.zero_grad()

            result = model.loss(image, (label_1, label_2))
            out1, out2 = result.out1, result.out2

            loss1 = result.loss1
            loss2 = result.loss2

            loss = result.loss
            
            # Mixed Precision
            if cfg.TRAIN.APEX:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()
            model.step()  # update model step count
            scheduler.step()

            # Print out the loss periodically.
            if steps % cfg.TRAIN.LOG_INTERVAL == 0:
                printf('Train Step: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLoss1: {:.6f}\tLoss2: {:.6f}'.format(
                    steps, batch_idx * len(image), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader), loss.data.item(),
                    loss1.data.item(), loss2.data.item()))

                # Log to tensorboard
                writer.add_scalar('lr', scheduler.get_lr()[0], steps)
                writer.add_scalar('loss/overall', loss.data.item(), steps)
                task1.log_visualize(out1, label_1, loss1, writer, steps)
                task2.log_visualize(out2, label_2, loss2, writer, steps)
                writer.add_image('image', process_image(image[0], train_data.image_mean), steps)

            if steps % cfg.TRAIN.SAVE_INTERVAL == 0:
                checkpoint = {
                    'cfg': cfg,
                    'step': steps,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': loss,
                    'loss1': loss1,
                    'loss2': loss2,
                    'task1_metric': None,
                    'task2_metric': None,
                }

                if cfg.TRAIN.EVAL_CKPT:
                    model.eval()
                    task1_metric, task2_metric = evaluate(test_loader, model, task1, task2)
                    for k, v in task1_metric.items():
                        writer.add_scalar('eval/{}'.format(k), v, steps)
                    for k, v in task2_metric.items():
                        writer.add_scalar('eval/{}'.format(k), v, steps)
                    for k, v in task1_metric.items():
                        printf('{}: {:.3f}'.format(k, v))
                    for k, v in task2_metric.items():
                        printf('{}: {:.3f}'.format(k, v))

                    checkpoint['task1_metric'] = task1_metric
                    checkpoint['task2_metric'] = task2_metric
                    model.train()

                torch.save(checkpoint, os.path.join(cfg.SAVE_DIR, cfg.EXPERIMENT_NAME,
                                                    'ckpt-%s.pth' % str(steps).zfill(5)))

            if steps >= cfg.TRAIN.STEPS:
                break
            steps += 1
Example #7
def main():
    args = parse_args()

    train_dir = os.path.join(args.root, args.train)
    trainannot_dir = os.path.join(args.root, args.trainannot)
    val_dir = os.path.join(args.root, args.val)
    valannot_dir = os.path.join(args.root, args.valannot)

    train_ds = get_dataset(train_dir, trainannot_dir, batch_size=4)
    val_ds = get_dataset(val_dir, valannot_dir, batch_size=4)
    # val_ds = None
    fine_tune = args.finetune
    # fine_tune=False

    # MSI_FCN
    if args.model == 'msi_fcn':
        model_config = {
            "input_scales": 4,
            "dcu_gr": 16,
            "dense_gr": 24,
            "filters": 64,
            "expansion": 2,
            "msc_filters": [2, 2, 2, 2],
            "k": (7, 5, 3, 1),
            "up_filters": 2,
            "num_layers": (4, 4, 4, 4),
            "num_classes": 2,
            "use_msc": True,
            "use_up_block": False
        }
        model = MSI_FCN(**model_config)

    # FCN-VGG
    elif args.model == 'fcn':
        model_config = {"filters": 64, "expansion": 2, "num_classes": 2}
        model = FCN_vgg16(**model_config)

    # FCD
    elif args.model == 'fcd':
        model_config = {
            "growth_rate": 12,
            "td_filters": [48, 112, 192, 304, 464, 656, 896],
            "up_filters": [1088, 816, 578, 384, 256],
            "down_layers": [4, 4, 4, 4, 4, 4],
            "up_layers": [4, 4, 4, 4, 4],
            "num_classes": 2
        }
        model = FCD(**model_config)
    else:
        raise ValueError("args.model should be 'msi_fcn', 'fcn' or 'fcd'.")

    work_dir = args.work_dir
    # print model params
    # model.build(input_shape=(None,256,256,3))
    # print(model.summary())
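    # Learning rate decays exponentially from 2e-4 by a factor of 0.95 every 5000 steps.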
    lr = tf.keras.optimizers.schedules.ExponentialDecay(2e-4, 5000, 0.95)
    optimizer = tf.keras.optimizers.Adam(lr)
    for k, v in model_config.items():
        print("{}: {}".format(k, v))
    fit(train_ds=train_ds,
        val_ds=val_ds,
        model=model,
        optimizer=optimizer,
        loss_func=WSCE,
        work_dir=work_dir,
        epochs=60,
        fine_tune=fine_tune)