Example #1
def init_callbacks(args, summary_writer, model, viz_sampler):
    """
    Initializes the callbacks used for training and validation.
    Parameters
    ----------
    args: Namespace
        arguments generated by argparse
    summary_writer: SummaryWriter
        object to use tensorboard for logging
    model: Module
        the neural network to be trained
    viz_sampler: ItemLoader
        sampler for data visualization

    Returns
    -------
    dict
        Dictionary with 'train' and 'eval' callback tuples

    """
    # Trailing comma makes this a one-element tuple, not a bare meter object
    train_cbs = (RunningAverageMeter(prefix='train', name='loss'),)
    val_cbs = (RunningAverageMeter(prefix='eval', name='loss'),
               ScalarMeterLogger(writer=summary_writer),
               ModelSaver(metric_names='eval/loss',
                          save_dir=args.snapshots,
                          conditions='min',
                          model=model),
               ImageSamplingVisualizer(generator_sampler=viz_sampler,
                                       writer=summary_writer,
                                       grid_shape=(args.grid_shape,
                                                   args.grid_shape)))
    return {
        'train': train_cbs,
        'eval': val_cbs,
    }
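
A minimal argparse sketch covering just the fields init_callbacks reads (args.snapshots for ModelSaver and args.grid_shape for the visualizer grid); the flag names and defaults below are assumptions, not the project's actual CLI:

import argparse

def parse_args():
    # Hypothetical CLI covering only the two fields init_callbacks touches.
    parser = argparse.ArgumentParser()
    parser.add_argument('--snapshots', default='snapshots',
                        help='directory where ModelSaver writes checkpoints')
    parser.add_argument('--grid_shape', type=int, default=8,
                        help='side length of the TensorBoard image grid')
    return parser.parse_args()
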
Example #2
File: train.py Project: MIPT-Oulu/Collagen
def main(cfg):
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    data_dir = os.path.join(os.environ['PWD'], cfg.data_dir)

    train_ds, classes = get_cifar10(data_folder=data_dir, train=True)
    eval_ds, _ = get_cifar10(data_folder=data_dir, train=False)
    n_channels = 3

    criterion = torch.nn.CrossEntropyLoss()

    model = ResNet(in_channels=n_channels, n_features=64, drop_rate=0.3).to(device).half()
    optimizer = torch.optim.SGD(params=model.parameters(), lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.wd,
                                nesterov=True)

    # Tensorboard visualization
    log_dir = cfg.log_dir
    comment = cfg.comment
    summary_writer = SummaryWriter(log_dir=log_dir, comment=comment)

    item_loaders = dict()
    for stage, df in zip(['train', 'eval'], [train_ds, eval_ds]):
        item_loaders[f'loader_{stage}'] = ItemLoader(meta_data=df,
                                                     transform=my_transforms()[stage],
                                                     parse_item_cb=parse_item,
                                                     batch_size=cfg.bs, num_workers=cfg.num_workers,
                                                     shuffle=(stage == "train"))

    data_provider = DataProvider(item_loaders)

    train_cbs = (CosineAnnealingWarmRestartsWithWarmup(optimizer=optimizer, warmup_epochs=(0, 10, 20),
                                                       warmup_lrs=(0, 0.1, 0.01), T_0=5, T_mult=2, eta_min=0),
                 RunningAverageMeter(name="loss"),
                 AccuracyMeter(name="acc"))

    val_cbs = (RunningAverageMeter(name="loss"),
               AccuracyMeter(name="acc"),
               ScalarMeterLogger(writer=summary_writer),
               ModelSaver(metric_names='loss', save_dir=cfg.snapshots, conditions='min', model=model),
               ModelSaver(metric_names='acc', save_dir=cfg.snapshots, conditions='max', model=model))

    session = dict()
    session['mymodel'] = Session(data_provider=data_provider,
                                 train_loader_names=cfg.sampling.train.data_provider.mymodel.keys(),
                                 val_loader_names=cfg.sampling.eval.data_provider.mymodel.keys(),
                                 module=model, loss=criterion, optimizer=optimizer,
                                 train_callbacks=train_cbs,
                                 val_callbacks=val_cbs)

    strategy = Strategy(data_provider=data_provider,
                        data_sampling_config=cfg.sampling,
                        strategy_config=cfg.strategy,
                        sessions=session,
                        n_epochs=cfg.n_epochs,
                        device=device)

    strategy.run()
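
The dotted access into cfg (cfg.sampling.train.data_provider.mymodel, etc.) matches an OmegaConf-style nested config. A hedged entry-point sketch, assuming a config.yml with the fields used above; the original project may wire this through Hydra instead:

import argparse
from omegaconf import OmegaConf

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='config.yml')
    # OmegaConf turns nested YAML into attribute-style access (cfg.lr, cfg.sampling, ...)
    cfg = OmegaConf.load(parser.parse_args().config)
    main(cfg)
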
Example #3
                                         sigmoid=False,
                                         prefix="train/D",
                                         name="ss_valid"),
                         SSAccuracyMeter(prefix="train/D", name="ss_acc"))

    d_callbacks_eval = (SSValidityMeter(threshold=0.5,
                                        sigmoid=False,
                                        prefix="eval/D",
                                        name="ss_valid"),
                        SSAccuracyMeter(prefix="eval/D", name="ss_acc"),
                        SSConfusionMatrixVisualizer(
                            writer=summary_writer,
                            labels=[str(i) for i in range(10)],
                            tag="eval/confusion_matrix"))

    st_callbacks = (ScalarMeterLogger(writer=summary_writer),
                    ImageSamplingVisualizer(
                        generator_sampler=item_loaders['fake_unlabeled_gen'],
                        writer=summary_writer,
                        grid_shape=args.grid_shape))

    with open("settings.yml", "r") as f:
        # Explicit Loader: bare yaml.load() is deprecated and unsafe in PyYAML >= 5.1
        sampling_config = yaml.load(f, Loader=yaml.SafeLoader)

    d_trainer = Trainer(
        data_provider=data_provider,
        train_loader_names=tuple(
            sampling_config["train"]["data_provider"]["D"].keys()),
        val_loader_names=tuple(
            sampling_config["eval"]["data_provider"]["D"].keys()),
        module=d_network,
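
The Trainer above pulls its loader names out of nested keys of settings.yml. A hypothetical minimal shape for that file, shown as the dict yaml.load returns; the loader names are invented for illustration:

# Hypothetical parsed settings.yml; only the keys the Trainer reads are
# shown, and the loader names are made up.
sampling_config = {
    'train': {'data_provider': {'D': {'real_labeled': None,
                                      'fake_unlabeled': None}}},
    'eval': {'data_provider': {'D': {'real_labeled': None}}},
}
# tuple(sampling_config['train']['data_provider']['D'].keys())
# -> ('real_labeled', 'fake_unlabeled')
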
Example #4
def main(cfg):
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    data_dir = os.path.join(os.environ['PWD'], cfg.data_dir)

    train_ds, n_classes = get_cifar10(data_folder=data_dir, train=True)
    eval_ds, _ = get_cifar10(data_folder=data_dir, train=False)
    criterion = SemixupLoss(in_manifold_coef=8.0,
                            in_out_manifold_coef=8.0,
                            ic_coef=16.0)

    # Tensorboard visualization
    log_dir = cfg.log_dir
    comment = cfg.comment
    summary_writer = SummaryWriter(log_dir=log_dir, comment=comment)

    model = DawnNet(in_channels=3, n_features=64,
                    drop_rate=cfg.dropout).to(device)
    if cfg.pretrained_model:
        print(f'Loading model: {cfg.pretrained_model}')
        model.load_state_dict(torch.load(cfg.pretrained_model))
    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=cfg.lr,
                                momentum=cfg.momentum,
                                weight_decay=cfg.wd,
                                nesterov=True)
    # optimizer = torch.optim.Adam(params=model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
    # model = EfficientNet.from_name(cfg.arch_name).to(device)
    # if cfg.pretrained_model is not None and cfg.pretrained_model:
    #     print(f'Loading model: {cfg.pretrained_model}')
    #     model.load_state_dict(torch.load(cfg.pretrained_model))

    data_provider = semixup_data_provider(
        model=model,
        alpha=np.random.beta(cfg.alpha, cfg.alpha),
        n_classes=n_classes,
        train_labeled_data=train_ds,
        train_unlabeled_data=train_ds,
        val_labeled_data=eval_ds,
        transforms=my_transforms(),
        parse_item=parse_item,
        bs=cfg.bs,
        num_workers=cfg.num_workers,
        augmentation=my_transforms()['transforms'],
        data_rearrange=data_rearrange)

    train_cbs = (
        CosineAnnealingWarmRestartsWithWarmup(optimizer=optimizer,
                                              warmup_epochs=10,
                                              warmup_lrs=(1e-9, 0.1),
                                              T_0=5,
                                              T_mult=1,
                                              eta_min=0),
        # CycleRampUpDownScheduler(optimizer, initial_lr=0, rampup_epochs=15, rampup_lr=0.4,
        #                          start_cycle_epoch=20,
        #                          rampdown_epochs=25, cycle_interval=5, cycle_rampdown_epochs=0),
        # MultiLinearByBatchScheduler(optimizer=optimizer, n_batches=len(data_provider.get_loader_by_name('labeled_train')),
        #                             # steps=[0, 5, 10, 15, 20, 80], lrs=[0.05, 0.005, 0.01, 0.005, 0.001, 1e-8]),
        #                             steps=[0, 15, 30, 80], lrs=[0, 0.1, 0.005, 1e-8]),
        RunningAverageMeter(name="loss_cls"),
        RunningAverageMeter(name="loss_in_mnf"),
        RunningAverageMeter(name="loss_inout_mnf"),
        RunningAverageMeter(name="loss_ic"),
        AccuracyMeter(name="acc",
                      cond=cond_accuracy_meter,
                      parse_output=parse_class,
                      parse_target=parse_class))

    val_cbs = (RunningAverageMeter(name="loss_cls"),
               AccuracyMeter(name="acc",
                             cond=cond_accuracy_meter,
                             parse_output=parse_class,
                             parse_target=parse_class),
               ScalarMeterLogger(writer=summary_writer),
               ModelSaver(metric_names='loss_cls',
                          save_dir=cfg.snapshots,
                          conditions='min',
                          model=model),
               ModelSaver(metric_names='acc',
                          save_dir=cfg.snapshots,
                          conditions='max',
                          model=model))

    session = dict()
    session['mymodel'] = Session(
        data_provider=data_provider,
        train_loader_names=cfg.sampling.train.data_provider.mymodel.keys(),
        val_loader_names=cfg.sampling.eval.data_provider.mymodel.keys(),
        module=model,
        loss=criterion,
        optimizer=optimizer,
        train_callbacks=train_cbs,
        val_callbacks=val_cbs)

    strategy = Strategy(data_provider=data_provider,
                        data_sampling_config=cfg.sampling,
                        strategy_config=cfg.strategy,
                        sessions=session,
                        n_epochs=cfg.n_epochs,
                        device=device)

    strategy.run()
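
CosineAnnealingWarmRestartsWithWarmup is Collagen's own scheduler. As a rough sketch of the schedule its arguments describe (linear warmup between warmup_lrs over warmup_epochs, then cosine annealing with warm restarts of period T_0 growing by T_mult), one could compose stock PyTorch schedulers; the exact Collagen semantics are an assumption here:

import torch

def warmup_cosine_sketch(optimizer, warmup_epochs=10, warmup_lrs=(1e-9, 0.1),
                         T_0=5, T_mult=1, eta_min=0.0):
    # Assumes the optimizer's base lr equals warmup_lrs[1], so the linear
    # phase ramps from warmup_lrs[0] up to it over warmup_epochs steps.
    warmup = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=warmup_lrs[0] / warmup_lrs[1],
        total_iters=warmup_epochs)
    cosine = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=T_0, T_mult=T_mult, eta_min=eta_min)
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])
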
Example #5
        # Explicit Loader: bare yaml.load() is deprecated and unsafe in PyYAML >= 5.1
        sampling_config = yaml.load(f, Loader=yaml.SafeLoader)

    # Initializing the data provider
    data_provider = mt_data_provider(st_model=st_network,
                                     te_model=te_network,
                                     train_labeled_data=train_labeled_data,
                                     val_labeled_data=val_labeled_data,
                                     train_unlabeled_data=train_unlabeled_data,
                                     val_unlabeled_data=val_unlabeled_data,
                                     transforms=init_transforms(nc=n_channels),
                                     parse_item=parse_item,
                                     bs=args.bs,
                                     num_threads=args.num_threads,
                                     output_type='logits')
    # Setting up the callbacks
    stra_cbs = (ScalarMeterLogger(writer=summary_writer), ProgressbarLogger())

    # Trainers
    st_train_cbs = (
        CycleRampUpDownScheduler(
            optimizer=st_optim,
            initial_lr=args.initial_lr,
            rampup_epochs=args.lr_rampup,
            lr=args.lr,
            rampdown_epochs=args.lr_rampdown_epochs,
            start_cycle_epoch=args.start_cycle_epoch,
            cycle_interval=args.cycle_interval,
            cycle_rampdown_epochs=args.cycle_rampdown_epochs),
        # SingleRampUpDownScheduler(optimizer=st_optim, initial_lr=args.initial_lr, rampup_epochs=args.rampup_epochs,
        #                           lr=args.lr, rampdown_epochs=args.rampdown_epochs),
        RunningAverageMeter(prefix='train/S', name='loss_cls'),
Example #6
File: train.py Project: MIPT-Oulu/Collagen
def main(cfg):
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    log_dir = os.path.join(os.getcwd(), cfg.log_dir)
    summary_writer = SummaryWriter(log_dir=log_dir, comment=cfg.comment)

    # Initializing Discriminator
    d_network = Discriminator(nc=1, ndf=cfg.d_net_features,
                              drop=cfg.dropout).to(device)
    d_optim = optim.Adam(d_network.parameters(),
                         lr=cfg.d_lr,
                         weight_decay=cfg.d_wd,
                         betas=(cfg.d_beta, 0.999))
    d_crit = BCELoss().to(device)

    # Initializing Generator
    g_network = Generator(nc=1, nz=cfg.latent_size,
                          ngf=cfg.g_net_features).to(device)
    g_optim = optim.Adam(g_network.parameters(),
                         lr=cfg.g_lr,
                         weight_decay=cfg.g_wd,
                         betas=(cfg.g_beta, 0.999))
    g_crit = GeneratorLoss(d_network=d_network, d_loss=d_crit).to(device)

    # Initializing the data provider
    item_loaders = dict()
    data_dir = os.path.join(os.environ['PWD'], cfg.data_dir)
    train_ds, classes = get_mnist(data_folder=data_dir, train=True)
    data_provider = gan_data_provider(g_network, item_loaders, train_ds,
                                      classes, cfg.latent_size,
                                      init_mnist_transforms(),
                                      parse_item_mnist_gan, cfg.bs,
                                      cfg.num_threads, device)

    # Setting up the callbacks
    st_callbacks = (SamplingFreezer([d_network, g_network]),
                    ScalarMeterLogger(writer=summary_writer),
                    ImageSamplingVisualizer(
                        generator_sampler=item_loaders['fake'],
                        transform=lambda x: (x + 1.0) / 2.0,
                        writer=summary_writer,
                        grid_shape=(cfg.grid_shape, cfg.grid_shape)))

    # Session
    d_session = Session(
        data_provider=data_provider,
        train_loader_names=cfg.sampling.train.data_provider.D.keys(),
        val_loader_names=None,
        train_callbacks=(BatchProcFreezer(modules=g_network),
                         RunningAverageMeter(prefix="train/D", name="loss")),
        module=d_network,
        optimizer=d_optim,
        loss=d_crit)

    g_session = Session(
        data_provider=data_provider,
        train_loader_names=cfg.sampling.train.data_provider.G.keys(),
        val_loader_names=cfg.sampling.eval.data_provider.G.keys(),
        train_callbacks=(BatchProcFreezer(modules=d_network),
                         RunningAverageMeter(prefix="train/G", name="loss")),
        val_callbacks=RunningAverageMeter(prefix="eval/G", name="loss"),
        module=g_network,
        optimizer=g_optim,
        loss=g_crit)

    sessions = {'D': d_session, 'G': g_session}

    # Strategy
    dcgan = Strategy(data_provider=data_provider,
                     data_sampling_config=cfg.sampling,
                     strategy_config=cfg.strategy,
                     sessions=sessions,
                     n_epochs=cfg.n_epochs,
                     callbacks=st_callbacks,
                     device=device)

    dcgan.run()
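
GeneratorLoss here wraps the discriminator and its BCE criterion. A minimal sketch of that standard DCGAN construction (score generated images with D and push them toward the "real" label); Collagen's exact call signature is an assumption:

import torch
from torch import nn

class GeneratorLossSketch(nn.Module):
    """Hypothetical stand-in for Collagen's GeneratorLoss: the
    non-saturating generator objective, max log D(G(z))."""

    def __init__(self, d_network: nn.Module, d_loss: nn.Module):
        super().__init__()
        self.d_network = d_network
        self.d_loss = d_loss

    def forward(self, fake_images, target=None):
        d_out = self.d_network(fake_images)
        # The generator wants the discriminator to call its samples real.
        return self.d_loss(d_out, torch.ones_like(d_out))
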
Example #7
def worker_process(gpu, ngpus, sampling_config, strategy_config, args):
    args.gpu = gpu  # this line of code is not redundant
    if args.distributed:
        lr_m = float(args.batch_size * args.world_size) / 256.
    else:
        lr_m = 1.0
    criterion = torch.nn.CrossEntropyLoss().to(gpu)
    train_ds, classes = get_mnist(data_folder=args.save_data, train=True)
    test_ds, _ = get_mnist(data_folder=args.save_data, train=False)
    model = SimpleConvNet(bw=args.bw,
                          drop=args.dropout,
                          n_cls=len(classes),
                          n_channels=args.n_channels).to(gpu)
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=args.lr * lr_m,
                                 weight_decay=args.wd)

    args, model, optimizer = convert_according_to_args(args=args,
                                                       gpu=gpu,
                                                       ngpus=ngpus,
                                                       network=model,
                                                       optim=optimizer)

    item_loaders = dict()
    for stage, df in zip(['train', 'eval'], [train_ds, test_ds]):
        if args.distributed:
            item_loaders[f'mnist_{stage}'] = DistributedItemLoader(
                meta_data=df,
                transform=init_mnist_cifar_transforms(1, stage),
                parse_item_cb=parse_item_mnist,
                args=args)
        else:
            item_loaders[f'mnist_{stage}'] = ItemLoader(
                meta_data=df,
                transform=init_mnist_cifar_transforms(1, stage),
                parse_item_cb=parse_item_mnist,
                batch_size=args.batch_size,
                num_workers=args.workers,
                shuffle=(stage == "train"))
    data_provider = DataProvider(item_loaders)
    if args.gpu == 0:
        log_dir = args.log_dir
        comment = args.comment
        summary_writer = SummaryWriter(log_dir=log_dir,
                                       comment=f'_{comment}gpu_{args.gpu}')
        train_cbs = (RunningAverageMeter(prefix="train", name="loss"),
                     AccuracyMeter(prefix="train", name="acc"))

        val_cbs = (RunningAverageMeter(prefix="eval", name="loss"),
                   AccuracyMeter(prefix="eval", name="acc"),
                   ScalarMeterLogger(writer=summary_writer),
                   ModelSaver(metric_names='eval/loss',
                              save_dir=args.snapshots,
                              conditions='min',
                              model=model))
    else:
        train_cbs = ()
        val_cbs = ()

    strategy = Strategy(data_provider=data_provider,
                        train_loader_names=tuple(
                            sampling_config['train']['data_provider'].keys()),
                        val_loader_names=tuple(
                            sampling_config['eval']['data_provider'].keys()),
                        data_sampling_config=sampling_config,
                        loss=criterion,
                        model=model,
                        n_epochs=args.n_epochs,
                        optimizer=optimizer,
                        train_callbacks=train_cbs,
                        val_callbacks=val_cbs,
                        device=torch.device('cuda:{}'.format(args.gpu)),
                        distributed=args.distributed,
                        use_apex=args.use_apex)

    strategy.run()
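
worker_process takes the process index as its first argument, which matches torch.multiprocessing.spawn's calling convention. A hedged launcher sketch; the script's actual launch code is not shown in this excerpt:

import torch
import torch.multiprocessing as mp

def launch(sampling_config, strategy_config, args):
    ngpus = torch.cuda.device_count()
    if args.distributed:
        # mp.spawn calls worker_process(rank, *args), so the process
        # index arrives as the `gpu` parameter above.
        mp.spawn(worker_process, nprocs=ngpus,
                 args=(ngpus, sampling_config, strategy_config, args))
    else:
        worker_process(0, ngpus, sampling_config, strategy_config, args)
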
Example #8
File: train.py Project: MIPT-Oulu/Collagen
def main(cfg):
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    data_dir = os.path.join(os.environ['PWD'], cfg.data_dir)

    train_ds, classes = get_mnist(data_folder=data_dir, train=True)
    n_classes = len(classes)
    n_channels = 1

    criterion = torch.nn.CrossEntropyLoss()

    # Tensorboard visualization
    log_dir = cfg.log_dir
    comment = cfg.comment
    summary_writer = SummaryWriter(log_dir=log_dir, comment=comment)

    splitter = FoldSplit(train_ds, n_folds=5, target_col="target")

    for fold_id, (df_train, df_val) in enumerate(splitter):
        item_loaders = dict()

        for stage, df in zip(['train', 'eval'], [df_train, df_val]):
            item_loaders[f'loader_{stage}'] = ItemLoader(
                meta_data=df,
                transform=my_transforms()[stage],
                parse_item_cb=parse_item,
                batch_size=cfg.bs,
                num_workers=cfg.num_threads,
                shuffle=(stage == "train"))

        model = SimpleConvNet(bw=cfg.bw,
                              drop_rate=cfg.dropout,
                              n_classes=n_classes).to(device)
        optimizer = torch.optim.Adam(params=model.parameters(),
                                     lr=cfg.lr,
                                     weight_decay=cfg.wd)
        data_provider = DataProvider(item_loaders)

        train_cbs = (RunningAverageMeter(name="loss"),
                     AccuracyMeter(name="acc"))

        val_cbs = (RunningAverageMeter(name="loss"), AccuracyMeter(name="acc"),
                   ScalarMeterLogger(writer=summary_writer),
                   ModelSaver(metric_names='loss',
                              save_dir=cfg.snapshots,
                              conditions='min',
                              model=model),
                   ModelSaver(metric_names='acc',
                              save_dir=cfg.snapshots,
                              conditions='max',
                              model=model))

        session = dict()
        session['mymodel'] = Session(
            data_provider=data_provider,
            train_loader_names=cfg.sampling.train.data_provider.mymodel.keys(),
            val_loader_names=cfg.sampling.eval.data_provider.mymodel.keys(),
            module=model,
            loss=criterion,
            optimizer=optimizer,
            train_callbacks=train_cbs,
            val_callbacks=val_cbs)

        strategy = Strategy(data_provider=data_provider,
                            data_sampling_config=cfg.sampling,
                            strategy_config=cfg.strategy,
                            sessions=session,
                            n_epochs=cfg.n_epochs,
                            device=device)

        strategy.run()
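
FoldSplit is Collagen's cross-validation splitter. As a sketch of the contract the loop above relies on, namely an iterable of (df_train, df_val) pairs stratified on target_col, an equivalent built on scikit-learn could look like this; an illustration, not Collagen's implementation:

import pandas as pd
from sklearn.model_selection import StratifiedKFold

def fold_split_sketch(df: pd.DataFrame, n_folds: int = 5,
                      target_col: str = 'target', seed: int = 0):
    # Yield one (train_df, val_df) pair per fold, class-balanced on target_col.
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
    for train_idx, val_idx in skf.split(df, df[target_col]):
        yield df.iloc[train_idx], df.iloc[val_idx]
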
Example #9
    # Use teacher network here
    data_provider = mixmatch_ema_data_provider(
        model=te_network,
        labeled_meta_data=train_labeled_data,
        parse_item=parse_item,
        unlabeled_meta_data=train_unlabeled_data,
        bs=args.bs,
        augmentation=init_transforms(nc=n_channels)[2],
        n_augmentations=2,
        num_threads=args.num_threads,
        val_labeled_data=val_labeled_data,
        transforms=init_transforms(nc=n_channels))

    summary_writer = SummaryWriter(log_dir=log_dir, comment=comment)
    # Callbacks
    scheme_cbs = (ScalarMeterLogger(writer=summary_writer),
                  ProgressbarLogger())

    st_train_cbs = (RunningAverageMeter(prefix='train/S', name='loss_x'),
                    RunningAverageMeter(prefix='train/S', name='loss_u'),
                    ScalarMeterLogger(writer=summary_writer),
                    AccuracyMeter(prefix="train/S",
                                  name="acc",
                                  parse_target=parse_target,
                                  parse_output=parse_output,
                                  cond=cond_accuracy_meter),
                    KappaMeter(prefix='train/S',
                               name='kappa',
                               parse_target=parse_target_cls,
                               parse_output=parse_output_cls,
                               cond=cond_accuracy_meter),