Example #1

import numpy as np

import models  # project-local module providing the transformer/solver/router classes
# `count_number_transforms`, `count_number_transforms_after_last_downsample`,
# `define_solver`, and `define_router` are assumed to be defined alongside
# this snippet in the same module.
def define_transformer(version, input_nc, input_width, input_height, **kwargs):
    if version == 1:  # Identity function
        return models.Identity(input_nc, input_width, input_height, **kwargs)
    elif version == 2:  # 1 conv layer
        return models.JustConv(input_nc, input_width, input_height, **kwargs)
    elif version == 3:  # 1 conv layer + 1 max pooling
        return models.ConvPool(input_nc, input_width, input_height, **kwargs)
    elif version == 4:  # Bottle-neck residual block
        return models.ResidualTransformer(input_nc, input_width, input_height, **kwargs)
    elif version == 5:  # VGG13: 2 conv layers + 1 max pooling
        return models.VGG13ConvPool(input_nc, input_width, input_height, **kwargs)
    else:
        raise NotImplementedError(
            "Specified transformer module not available.")
def define_node(
    args,
    node_index,
    level,
    parent_index,
    tree_struct,
    identity=False,
):
    """ Define node operations.
    
    In this function, we assume that 3 building blocks of node operations
    i.e. transformer, solver and router are of fixed complexity. 
    """

    # define meta information
    num_transforms = 0 if node_index == 0 else count_number_transforms(
        parent_index, tree_struct)
    meta = {
        'index': node_index,
        'parent': parent_index,
        'left_child': 0,
        'right_child': 0,
        'level': level,
        'extended': False,
        'split': False,
        'visited': False,
        'is_leaf': True,
        'train_accuracy_gain_split': -np.inf,
        'valid_accuracy_gain_split': -np.inf,
        'test_accuracy_gain_split': -np.inf,
        'train_accuracy_gain_ext': -np.inf,
        'valid_accuracy_gain_ext': -np.inf,
        'test_accuracy_gain_ext': -np.inf,
        'num_transforms': num_transforms
    }

    # get input shape before transformation
    if not tree_struct:  # the first node: use the input data size
        meta['in_shape'] = (1, args.input_nc, args.input_width,
                            args.input_height)
    else:
        meta['in_shape'] = tree_struct[parent_index]['out_shape']

    # -------------------------- define transformer ---------------------------
    # no transformation if the input size is too small.
    if meta['in_shape'][2] < 3 or meta['in_shape'][3] < 3:
        identity = True

    if identity or args.transformer_ver == 1:
        meta['transformed'] = False
    else:
        meta['transformed'] = True

    # only downsample at the specified frequency;
    # the initial transform is currently assumed to always perform downsampling.
    num_downsample = 0 if node_index == 0 else \
        count_number_transforms_after_last_downsample(parent_index, tree_struct)
    meta['downsampled'] = (node_index == 0
                           or num_downsample == args.downsample_interval)

    # assemble the transformer configuration and pick its version:
    config_t = {
        'kernel_size': args.transformer_k,
        'ngf': args.transformer_ngf,
        'batch_norm': args.batch_norm,
        'downsample': meta['downsampled'],
        'expansion_rate': args.transformer_expansion_rate,
        'reduction_rate': args.transformer_reduction_rate
    }
    transformer_ver = args.transformer_ver
    if identity:
        transformer = models.Identity(meta['in_shape'][1], meta['in_shape'][2],
                                      meta['in_shape'][3], **config_t)
    else:
        transformer = define_transformer(transformer_ver, meta['in_shape'][1],
                                         meta['in_shape'][2],
                                         meta['in_shape'][3], **config_t)
    meta['identity'] = identity

    # get output shape after transformation:
    meta['out_shape'] = transformer.outputshape
    print('---------------- data shape before/after transformer -------------')
    print(meta['in_shape'], type(meta['in_shape']))
    print(meta['out_shape'], type(meta['out_shape']))

    # ---------------------------- define solver ------------------------------
    config_s = {
        'no_classes': args.no_classes,
        'dropout_prob': args.solver_dropout_prob,
        'batch_norm': args.batch_norm
    }
    solver = define_solver(args.solver_ver, meta['out_shape'][1],
                           meta['out_shape'][2], meta['out_shape'][3],
                           **config_s)

    # ---------------------------- define router ------------------------------
    config_r = {
        'kernel_size': args.router_k,
        'ngf': args.router_ngf,
        'soft_decision': True,
        'stochastic': False,
        'dropout_prob': args.router_dropout_prob,
        'batch_norm': args.batch_norm
    }

    router = define_router(args.router_ver, meta['out_shape'][1],
                           meta['out_shape'][2], meta['out_shape'][3],
                           **config_r)

    # define module:
    module = {'transform': transformer, 'classifier': solver, 'router': router}

    return meta, module
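
For reference, a minimal driver sketch (not part of the original module; the
`args` namespace and the parent index of -1 are illustrative assumptions)
showing how a root node could be created and recorded:

tree_struct = []                    # one meta dict per grown node
meta, module = define_node(
    args,                           # argparse.Namespace carrying the fields read above
    node_index=0,                   # root node
    level=0,
    parent_index=-1,                # hypothetical "no parent" marker for the root
    tree_struct=tree_struct,        # empty, so in_shape falls back to the args input size
)
tree_struct.append(meta)
print(meta['in_shape'], '->', meta['out_shape'])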
Example #3

# Assumed module-level context for this method: `import copy`, `import random`,
# `from math import ceil`, plus the project-local `models` and `gr` (graph
# helper) modules and the global options object `opt`.
    def comp_arch_rand_sfn(self):
        """Derive a random variant of this architecture: replace nodes with
        identities, shrink channel widths per group, and add skip edges."""
        def shrink_n(F, ratio):
            # shrink F channels by `ratio`, rounded up to a positive
            # multiple of opt.ar_channel_mul
            m = opt.ar_channel_mul
            return max(1, int(ceil((1.0 - ratio) * F / m))) * m

        arch = copy.deepcopy(self)
        n = arch.n
        V = arch.V

        # stage 1: with probability p1, replace shape-preserving nodes with
        # identities; nodes 11, 50, and 125 are skipped because their widths
        # are re-derived explicitly below
        p1 = random.choice(opt.ar_p1)
        for i in range(n):
            if (random.random() < p1 and V[i].in_shape == V[i].out_shape
                    and i not in [11, 50, 125]):
                V[i].replace(models.Identity())

        # stage 2: per group, shrink channel widths by a random ratio p2
        opt.ar_p2[1] = min(0.9, opt.ar_p2[1])  # cap the upper bound of the ratio range
        for g in self.groups:  # groups hold node indices, so they index the copy's V too
            p2 = random.uniform(*opt.ar_p2)
            for j in g.inter:  # interior nodes: shrink both input and output widths
                Fi = shrink_n(list(V[j].in_shape)[1], p2)
                Fo = shrink_n(list(V[j].out_shape)[1], p2)
                V[j].shrink(Fi, Fo)
            for j in g.in_only:  # boundary nodes whose output width is fixed
                Fi = shrink_n(list(V[j].in_shape)[1], p2)
                Fo = list(V[j].out_shape)[1]
                V[j].shrink(Fi, Fo)
            for j in g.out_only:  # boundary nodes whose input width is fixed
                Fi = list(V[j].in_shape)[1]
                Fo = shrink_n(list(V[j].out_shape)[1], p2)
                V[j].shrink(Fi, Fo)

        # stage 3: re-derive the three skip-branch widths so each merge node
        # (13, 52, 127) still receives matching channel counts; F2 = F4 - F3
        # is the width the branch must contribute
        for src, merge, proj, follow in ((2, 13, 11, 12),
                                         (41, 52, 50, 51),
                                         (116, 127, 125, 126)):
            F1 = list(V[src].out_shape)[1]    # width feeding the branch
            F4 = list(V[merge].out_shape)[1]  # width leaving the merge node
            F3 = list(V[merge].in_shape)[1]   # trunk width into the merge node
            F2 = F4 - F3                      # width the branch must supply
            V[proj].shrink(F1, F2)
            V[follow].shrink(F2, F2)

        # stage 4: with probability p3, add forward edges between nodes whose
        # shapes match (never into Concat or Identity nodes)
        p3 = random.choice(opt.ar_p3)
        for i in range(n):
            for j in range(i + 1, n):
                if (random.random() < p3 and V[i].out_shape == V[j].in_shape
                        and not isinstance(V[j].base,
                                           (models.Concat, models.Identity))):
                    arch.E[i][j] = True

        # rebuild derived link structures and move the new architecture to the device
        arch.in_links, arch.out_links = gr.get_links(arch.E)
        arch.init_rep()
        arch.to(opt.device)
        return arch
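
As a worked example of the rounding in `shrink_n`: with F = 64, ratio = 0.3,
and opt.ar_channel_mul = 8 (an assumed value), the helper computes
ceil(0.7 * 64 / 8) * 8 = ceil(5.6) * 8 = 48, so shrunk widths always remain
positive multiples of the channel granularity. A standalone sketch:

from math import ceil

def shrink_n(F, ratio, m=8):        # m stands in for opt.ar_channel_mul
    return max(1, int(ceil((1.0 - ratio) * F / m))) * m

assert shrink_n(64, 0.3) == 48      # ceil(44.8 / 8) * 8
assert shrink_n(64, 0.0) == 64      # ratio 0 keeps the width unchanged
assert shrink_n(8, 0.99) == 8       # never shrinks below one multiple of m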