Example #1
def compute_data_files():
    '''Sort the structure strings in increasing order of complexity and return them.'''
    print('Loading structure metadata')
    structure_to_files = misc.load_file(
        'save_state/neutral/structure_to_files.pkl')
    all_structure_strs = list(structure_to_files.keys())
    # The number of commas in the structure_str is used as a proxy for complexity
    all_structure_strs.sort(key=lambda x: x.count(','))
    # Dead code carried over in the source: filter SHAPE structures by file count.
    '''
    data_files = []
    for structure_str in all_structure_strs:
        data_i = structure_to_files[structure_str]
        if "SHAPE" in structure_str and 10000 < len(data_i) < 20000:
            data_files.extend(data_i)
            print(structure_str, len(data_i))

    structure_str = '(PROGRESSION,SHAPE,TYPE)'  # (XOR,SHAPE,SIZE)
    data_files.extend(structure_to_files[structure_str])
    print(structure_str, len(structure_to_files[structure_str]))
    '''
    return all_structure_strs
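A quick way to see what the comma-count proxy does is to sort a few structure strings by hand. A minimal sketch; the strings below are made-up examples in the same naming style, not taken from the dataset:

# Hypothetical structure strings; real ones come from structure_to_files.pkl.
strs = [
    '(AND,SHAPE,COLOR),(XOR,LINE,TYPE)',
    '(PROGRESSION,SHAPE,TYPE)',
    '(XOR,SHAPE,SIZE)',
]
strs.sort(key=lambda x: x.count(','))
print(strs)
# Fewer commas (fewer rule components) sort first; Python's sort is stable,
# so equal-comma strings keep their original relative order:
# ['(PROGRESSION,SHAPE,TYPE)', '(XOR,SHAPE,SIZE)',
#  '(AND,SHAPE,COLOR),(XOR,LINE,TYPE)']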
Example #2
def _create_train_info(self, args, init_train_info):
    import json
    train_info_path = os.path.join(self.logdir, 'train_info.txt')
    if os.path.isfile(train_info_path):
        # train_info.txt holds a JSON string; misc.load_file returns its raw text.
        self.train_info = json.loads(misc.load_file(train_info_path))
    else:
        self.train_info = init_train_info
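Across these examples, misc.load_file is handed pickle paths (structure_to_files.pkl), a plain-text path (train_info.txt, whose contents are then fed to json.loads), and in Example #4 below, .json paths whose parsed contents are used directly. The misc modules are not shown here and may even differ between projects; the following is only a plausible reconstruction, assuming extension-based dispatch:

import json
import pickle

def load_file(path):
    """Hypothetical stand-in for misc.load_file: unpickle .pkl, parse .json,
    and return raw text for anything else (e.g. the .txt above)."""
    if path.endswith('.pkl'):
        with open(path, 'rb') as f:
            return pickle.load(f)
    if path.endswith('.json'):
        with open(path) as f:
            return json.load(f)
    with open(path) as f:
        return f.read()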
Example #3
def main(args):
    # Step 1: init data folders
    '''if os.path.exists('save_state/'+args.regime+'/normalization_stats.pkl'):
        print('Loading normalization stats')
        x_mean, x_sd = misc.load_file('save_state/'+args.regime+'/normalization_stats.pkl')
    else:
        x_mean, x_sd = preprocess.save_normalization_stats(args.regime)
        print('x_mean: %.3f, x_sd: %.3f' % (x_mean, x_sd))'''

    # Step 2: init neural networks
    print("network is:", args.net)
    if args.net == "resnet":
        model = resnet50(pretrained=True)
    elif args.net == 'wresnet':
        model = wresnet50(pretrained=True)
    elif args.net == "tmp":
        model = tmp()
    elif args.net == 'RN_mlp':
        model = WildRelationNet()
    elif args.net == 'ReasonNet':
        model = ReasonNet()
    elif args.net == 'ReaP':
        model = ReasonNet_p()
    elif args.net == 'Reap16':
        model = ReasonNet_p16()
    elif args.net == 'Reaap':
        model = ReasonNet_ap()
    elif args.net == 'RN_ap':
        model = RN_ap()
    elif args.net == 'RN_r':
        model = RN_r()
    elif args.net == 'esemble':
        model = esemble()
    elif args.net == 'RNap2':
        model = RNap2()
    elif args.net == 'rn_mlp':
        model = rn_mlp()
    elif args.net == 'Reab3p16':
        model = Reab3p16()
    elif args.net == 'b3pa':
        model = b3pa()
    elif args.net == "b3_plstm":
        model = b3_plstm()
    elif args.net == "b3_palstm":
        model = b3palstm()
    elif args.net == "nmn":
        model = nmn()
    elif args.net == "b3p3":
        model = b3p3()
    elif args.net == "multi3":
        model = multi3()
    elif args.net == "split":
        model = b3_split()
    if args.gpunum > 1:
        model = nn.DataParallel(model, device_ids=range(args.gpunum))
    if args.net != 'RN_r':
        model.apply(weights_init)
        print('weights initialized')
    weights_path = args.path_weight
    if os.path.exists(weights_path) and args.restore:
        pretrained_dict = torch.load(weights_path)
        model_dict = model.state_dict()
        # Keep only the checkpoint entries whose names exist in the current
        # model, so a partially matching checkpoint can still be restored.
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        # optimizer.load_state_dict(torch.load(optimizer_path))
        print('loaded weights from', weights_path)
    model.cuda()
    epoch_count = 1
    print(time.strftime('%H:%M:%S', time.localtime(time.time())), 'testing')

    print('Loading structure metadata')
    structure_to_files = misc.load_file(
        'save_state/neutral/structure_to_files.pkl')

    all_structure_strs = list(structure_to_files.keys())
    # The number of commas in the structure_str is used as a proxy for complexity

    accuracy_all = []
    for structure_str in all_structure_strs:
        data_files = []
        data_i = structure_to_files[structure_str]
        if ("SHAPE" in structure_str
            ) and len(data_i) > 10000:  # and len(data_i) < 20000:
            data_files.extend(data_i)
        else:
            continue
        test_files = [
            data_file for data_file in data_files if 'test' in data_file
        ]

        test_loader = torch.utils.data.DataLoader(Dataset(test_files),
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=args.numwork)

        since = time.time()
        model.eval()
        accuracy_epoch = []
        for x, y in test_loader:
            x, y = Variable(x).cuda(), Variable(y).cuda()
            pred = model(x)

            pred = pred.data.max(1)[1]
            correct = pred.eq(y.data).cpu().sum().numpy()
            accuracy = correct * 100.0 / len(y)
            accuracy_epoch.append(accuracy)
            accuracy_all.append(accuracy)

        acc = sum(accuracy_epoch) / len(accuracy_epoch)

        print('epoch:%d, acc:%.1f' % (epoch_count, acc),
              'test_num:', len(test_files), structure_str)
        epoch_count += 1

    print('overall acc:%.1f' %
          (sum(accuracy_all) / len(accuracy_all)))
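The restore logic in the middle of main() (load a checkpoint, keep only keys present in the model's state_dict, update, reload) is a common PyTorch idiom for partially matching checkpoints. A self-contained sketch of the same idea as a reusable helper; the function name and the extra shape check are ours, not the repo's:

import torch

def restore_matching_weights(model, checkpoint_path):
    """Load only checkpoint tensors whose names and shapes match the model;
    everything else keeps its current (e.g. freshly initialized) values."""
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model_state = model.state_dict()
    matched = {k: v for k, v in checkpoint.items()
               if k in model_state and v.shape == model_state[k].shape}
    model_state.update(matched)
    model.load_state_dict(model_state)
    return sorted(matched)  # names that were actually restored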
Example #4
from functools import wraps
from typing import Callable

import discord
from discord import PermissionOverwrite
from discord.ext import commands
from discord.ext.commands import Bot

from meme_collections import (add_to_collection, delete_from_collection,
                              get_from_collection, handle_collections)
from misc import get_messages, load_file, split_word_by_step, turn_into_emoji

# global vars
bot = Bot(command_prefix='~')

# data vars
CONFIG = load_file('data/config.json')
SPECIAL_IDS = load_file('data/ids.json')
SUPER_MODERATOR = SPECIAL_IDS[0]
last_message = ''

def needs_permission(bot: discord.ext.commands.Bot, hidden: bool = False):
    """
    Decorator that registers a bot command and runs it only if the caller
    has sufficient permissions (SUPER_MODERATOR or the Moderator role).
    """
    def decorator(func: Callable, hidden=hidden):
        @bot.command(pass_context=True, hidden=hidden)
        @wraps(func)
        async def wrapped(ctx, *args, **kwargs):
            if ctx.message.author.id == SUPER_MODERATOR or any(
                    role.name == 'Moderator' for role in ctx.message.author.roles):
                return await func(ctx, *args, **kwargs)
            await bot.say("You don't have the permissions to run this command!")
        return wrapped
    return decorator
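pass_context and bot.say come from the pre-1.0 discord.py API, so each registered command receives its invocation context as the first argument. A hypothetical usage sketch; the command name and body are placeholders, not from the bot:

@needs_permission(bot, hidden=True)
async def shout(ctx, *words):
    # Moderator-only echo command: only runs if the permission check passes.
    await bot.say(' '.join(words).upper())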
Example #5
def compute_data_files(regime, n, args):
    '''Collect data file paths for the given regime. The sentinel value n
    selects a loading mode:
      -1: every file in the regime's data directory;
      -2: complexity-sorted train files, capped at 5000 per structure, plus all test files;
      -3: only structures with between 10000 and 20000 files;
      -4: files from a hard-coded local directory;
      otherwise: single-rule SHAPE/LINE structures chosen by args.image_type.
    '''
    data_dir = '../data/' + regime + '/'
    if n == -1:
        data_files = os.listdir(data_dir)
        data_files = [data_dir + data_file for data_file in data_files]
        return data_files
    elif n == -2:
        test_files = [
            data_dir + data_file for data_file in os.listdir(data_dir)
            if 'test' in data_file
        ]
        data_files = []
        if os.path.exists('save_state/' + regime + '/structure_to_files.pkl'):
            print('Loading structure metadata')
            structure_to_files = misc.load_file('save_state/' + regime +
                                                '/structure_to_files.pkl')
        else:
            structure_to_files = save_structure_to_files(regime)
        all_structure_strs = list(structure_to_files.keys())
        all_structure_strs.sort(key=lambda x: x.count(','))
        for structure_str in all_structure_strs:
            data_i = structure_to_files[structure_str]
            if len(data_i) > 5000:
                data_i = data_i[:5000]
            data_files.extend(data_i)
            #print(structure_str, len(structure_to_files[structure_str]))
        data_files = [
            data_file for data_file in data_files if 'train' in data_file
        ]
        data_files.extend(test_files)
        return data_files
    elif n == -3:
        data_files = []
        if os.path.exists('save_state/' + regime + '/structure_to_files.pkl'):
            print('Loading structure metadata')
            structure_to_files = misc.load_file('save_state/' + regime +
                                                '/structure_to_files.pkl')
        else:
            structure_to_files = save_structure_to_files(regime)
        all_structure_strs = list(structure_to_files.keys())
        all_structure_strs.sort(key=lambda x: x.count(','))
        for structure_str in all_structure_strs:
            data_i = structure_to_files[structure_str]
            if len(data_i) > 20000 or len(data_i) < 10000:
                continue
            data_files.extend(data_i)
            print(structure_str, len(structure_to_files[structure_str]))

        return data_files
    elif n == -4:
        data_files = os.listdir("/home/zkc/reason/andshapecolormask/")
        data_files = [
            "/home/zkc/reason/andshapecolormask/" + data_file
            for data_file in data_files
        ]
        return data_files

    else:
        data_files = []
        if os.path.exists('save_state/' + regime + '/structure_to_files.pkl'):
            print('Loading structure metadata')
            structure_to_files = misc.load_file('save_state/' + regime +
                                                '/structure_to_files.pkl')
        else:
            structure_to_files = save_structure_to_files(regime)
        all_structure_strs = list(structure_to_files.keys())
        # The number of commas in the structure_str is used as a proxy for complexity
        all_structure_strs.sort(key=lambda x: x.count(','))

        for structure_str in all_structure_strs:

            data_i = structure_to_files[structure_str]
            if args.image_type == "image":
                if "SHAPE" in structure_str and "),(" not in structure_str:
                    data_files.extend(data_i)
                    print(structure_str, ":", len(data_i))
                if "LINE" in structure_str and "),(" not in structure_str:
                    data_files.extend(data_i)
                    print(structure_str, ":", len(data_i))
            elif args.image_type == "shape_im":
                if "SHAPE" in structure_str and "),(" not in structure_str:
                    data_files.extend(data_i)
                    print(structure_str, ":", len(data_i))
            elif args.image_type == "line_im":
                if "LINE" in structure_str and "),(" not in structure_str:
                    data_files.extend(data_i)
                    print(structure_str, ":", len(data_i))

        return data_files
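For reference, a hedged sketch of how this loader might be driven; the argparse setup below only mirrors the one attribute the function reads (args.image_type), while the real CLI definitions live elsewhere in the repo:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--image_type', default='image',
                    choices=['image', 'shape_im', 'line_im'])
args = parser.parse_args([])

# Mode -2: complexity-sorted, per-structure-capped train files plus all test files.
train_and_test = compute_data_files('neutral', -2, args)
print(len(train_and_test), 'files selected')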