Example #1
def test_SinglePathLayer(args, myargs):
    # SinglePathLayer is assumed to be imported at module level.
    import torch
    import yaml
    from easydict import EasyDict
    from template_lib.utils import seed_utils
    seed_utils.set_random_seed(0)

    yaml_file = 'enas_cgan/configs/enas_cgan.yaml'
    with open(yaml_file, 'r') as f:
        configs = EasyDict(yaml.load(f, Loader=yaml.FullLoader))
    ops = configs.search_cgan_gen_cifar10_v6.model.generator.ops

    layer_id = 0
    num_layers = 6
    in_planes = 3
    out_planes = 32
    n_classes = 10

    mixedlayer = SinglePathLayer(layer_id=layer_id,
                                 in_planes=in_planes,
                                 out_planes=out_planes,
                                 ops=ops,
                                 track_running_stats=False,
                                 scalesample=None,
                                 bn_type='none').cuda()
    bs = len(ops)  # one input sample per candidate op
    x = torch.rand(bs, in_planes, 32, 32).cuda()
    # sample_arcs = torch.randint(0, len(ops), (bs, num_layers))
    # sample_arc = sample_arcs[:, layer_id]
    sample_arc = torch.arange(0, bs)  # sample i exercises op i
    out = mixedlayer(x, sample_arc=sample_arc)
    # Render the autograd graph of the output for inspection.
    import torchviz
    g = torchviz.make_dot(out)
    g.view()
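The call above routes each of the len(ops) input samples through the op selected by its sample_arc entry. A minimal sketch of that per-sample routing, with a hypothetical ToySinglePath standing in for the repository's SinglePathLayer:

import torch
import torch.nn as nn

class ToySinglePath(nn.Module):
    # Hypothetical stand-in: each candidate op here is a plain conv,
    # whereas the real layer builds its ops from the YAML config.
    def __init__(self, in_planes, out_planes, num_ops):
        super().__init__()
        self.ops = nn.ModuleList(
            nn.Conv2d(in_planes, out_planes, 3, padding=1)
            for _ in range(num_ops))

    def forward(self, x, sample_arc):
        # Route sample i through the op indexed by sample_arc[i].
        outs = [self.ops[int(arc)](xi.unsqueeze(0))
                for xi, arc in zip(x, sample_arc)]
        return torch.cat(outs, dim=0)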
Example #2
def train(args, myargs):
    # `dist` and SummaryWriter are assumed to be module-level imports,
    # e.g. torch.distributed and tensorboardX (the latter matching the
    # `logdir` keyword used below).
    try:
        rank = dist.get_rank()
        size = dist.get_world_size()
        args.rank = rank
        args.size = size
        if rank == 0:
            # Only the rank-0 process writes TensorBoard summaries.
            myargs.writer = SummaryWriter(logdir=args.tbdir)
    except Exception:
        print('torch.distributed is not initialized; running single-process.')
    myargs.config = getattr(myargs.config, args.command)
    config = myargs.config
    seed_utils.set_random_seed(config.seed)
    trainer = trainer_dict[args.command](args=args, myargs=myargs)

    if args.evaluate:
        trainer.evaluate()
        return

    if args.resume:
        trainer.resume()
    elif args.finetune:
        trainer.finetune()

    # Load dataset
    trainer.dataset_load()

    trainer.train()
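The bare try/except above treats any failure as "not distributed". A narrower guard, assuming `dist` is torch.distributed, would query its state explicitly:

import torch.distributed as dist

def get_rank_and_size():
    # Fall back to a single-process default instead of catching
    # arbitrary exceptions.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1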
Example #3
def test_MixedLayerCondSharedWeights(args, myargs):
    # MixedLayerCondSharedWeights and the BigGAN-style `layers` module
    # are assumed to be imported at module level.
    import functools
    import torch
    import torch.nn as nn
    import yaml
    from easydict import EasyDict
    from template_lib.utils import seed_utils
    seed_utils.set_random_seed(0)

    yaml_file = 'enas_cgan/configs/enas_cgan.yaml'
    with open(yaml_file, 'r') as f:
        configs = EasyDict(yaml.load(f, Loader=yaml.FullLoader))
    ops = configs.search_cgan_gen_cifar10_v5.model.generator.ops

    bs = 16
    in_planes = 3
    out_planes = 64
    layer_id = 0
    num_layers = 6
    n_classes = 10
    shared_dim = 128
    z_chunk_size = 120 // 4  # BigGAN-style: z is split into per-layer chunks
    G_shared = True
    which_linear = functools.partial(layers.SNLinear,
                                     num_svs=1,
                                     num_itrs=1,
                                     eps=1e-6)
    which_embedding = nn.Embedding
    bn_linear = (functools.partial(which_linear, bias=False)
                 if G_shared else which_embedding)
    which_bn = functools.partial(
        layers.ccbn,
        which_linear=bn_linear,
        cross_replica=False,
        mybn=False,
        input_size=(shared_dim + z_chunk_size if G_shared else n_classes),
        norm_style='bn',
        eps=1e-5)
    mixedlayer = MixedLayerCondSharedWeights(layer_id=layer_id,
                                             in_planes=in_planes,
                                             out_planes=out_planes,
                                             ops=ops,
                                             track_running_stats=False,
                                             scalesample=None,
                                             which_bn=which_bn).cuda()
    x = torch.rand(bs, in_planes, 32, 32).cuda()
    y = torch.randint(0, n_classes, (bs, )).cuda()
    shared = (which_embedding(n_classes, shared_dim)
              if G_shared else layers.identity()).cuda()

    shared_y = shared(y)
    shared_y = torch.cat((shared_y, torch.rand(bs, z_chunk_size).cuda()),
                         dim=1)
    sample_arcs = torch.randint(0, len(ops), (bs, num_layers))
    out = mixedlayer(x, y=shared_y, sample_arc=sample_arcs[:, layer_id])
    # Render the autograd graph of the output for inspection.
    import torchviz
    g = torchviz.make_dot(out)
    g.view()
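Here the conditioning vector fed to each conditional BN is the shared class embedding concatenated with a chunk of the latent, as in BigGAN. A minimal sketch of class-conditional batchnorm, with a hypothetical ToyCCBN standing in for layers.ccbn:

import torch
import torch.nn as nn

class ToyCCBN(nn.Module):
    # The conditioning vector is mapped to per-channel gain and bias;
    # the BN itself carries no learned affine parameters.
    def __init__(self, num_channels, input_size):
        super().__init__()
        self.gain = nn.Linear(input_size, num_channels, bias=False)
        self.bias = nn.Linear(input_size, num_channels, bias=False)
        self.bn = nn.BatchNorm2d(num_channels, affine=False)

    def forward(self, x, y):
        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
        bias = self.bias(y).view(y.size(0), -1, 1, 1)
        return self.bn(x) * gain + bias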
Example #4
def main(trainer, args, myargs):
    config = myargs.config

    from template_lib.utils import seed_utils
    seed_utils.set_random_seed(config.seed)

    if args.evaluate:
        trainer.evaluate()
        return

    if args.resume:
        trainer.resume()
    elif args.finetune:
        trainer.finetune()

    # Load dataset
    trainer.dataset_load()

    trainer.train()
Example #5
def run_trainer(args, myargs):
    myargs.config = getattr(myargs.config, args.command)
    config = myargs.config
    seed_utils.set_random_seed(config.seed)
    trainer = trainer_dict[args.command](args=args, myargs=myargs)

    if args.evaluate:
        trainer.evaluate()
        return

    if args.resume:
        trainer.resume()
    elif args.finetune:
        trainer.finetune()

    # Load dataset
    trainer.dataset_load()

    trainer.train()
Example #6
def train(args, myargs):
    myargs.config = getattr(myargs.config, args.command)
    config = myargs.config
    print(pformat(OrderedDict(config)))  # log the resolved config
    trainer = trainer_dict[args.command](args=args, myargs=myargs)

    seed_utils.set_random_seed(config.seed)

    if args.evaluate:
        trainer.evaluate()
        return

    if args.resume:
        trainer.resume()
    elif args.finetune:
        trainer.finetune()

    # Load dataset
    trainer.dataset_load()

    trainer.train()
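Examples #2 and #4 through #6 all drive the same small trainer interface; a hypothetical skeleton of what the classes registered in trainer_dict are expected to provide:

class TrainerSkeleton:
    # Hypothetical interface implied by the drivers above; the concrete
    # trainer classes are repo-specific.
    def evaluate(self): ...
    def resume(self): ...
    def finetune(self): ...
    def dataset_load(self): ...
    def train(self): ...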
Example #7
def main(args, myargs):
    config = myargs.config.main
    logger = myargs.logger

    from template_lib.utils import seed_utils
    seed_utils.set_random_seed(config.seed)

    # Create train_dict
    train_dict = init_train_dict()
    myargs.checkpoint_dict['train_dict'] = train_dict

    # Create trainer
    trainer = trainer_create(args=args, myargs=myargs)

    if args.evaluate:
        trainer.evaluate()
        return

    if args.resume:
        logger.info('=> Resume from: %s', args.resume_path)
        loaded_state_dict = myargs.checkpoint.load_checkpoint(
            checkpoint_dict=myargs.checkpoint_dict,
            resumepath=args.resume_path)
        # Restore only the keys the current code still knows about.
        for key in train_dict:
            if key in loaded_state_dict['train_dict']:
                train_dict[key] = loaded_state_dict['train_dict'][key]

    # Load dataset
    trainer.dataset_load()

    for epoch in range(train_dict['epoch_done'], config.epochs):
        logger.info('epoch: [%d/%d]', epoch, config.epochs)

        trainer.train_one_epoch()

        train_dict['epoch_done'] += 1
        # Evaluate after each epoch.
        trainer.test()
    trainer.finalize()
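The loop resumes from train_dict['epoch_done'], which the selective restore above recovers from the checkpoint. A hypothetical init_train_dict consistent with that usage (only 'epoch_done' is referenced here; the real helper may track more counters):

def init_train_dict():
    return {'epoch_done': 0}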