Exemplo n.º 1
0
def load_execution_engine(path, verbose=True):
    """Rebuild an execution-engine model from a saved checkpoint.

    The model class is selected by ``checkpoint['args']['model_type']``. It is
    constructed from the stored ``execution_engine_kwargs``, and the stored
    ``execution_engine_state`` is then loaded into it.

    Args:
        path: checkpoint location, passed unchanged to ``load_cpu``.
        verbose: stored into the constructor kwargs under ``'verbose'``.

    Returns:
        ``(model, kwargs)``: the restored model and the keyword arguments it
        was constructed with (after any defaults/updates applied below).

    Raises:
        ValueError: if ``model_type`` is not one of the supported types.
        KeyError: if the checkpoint lacks one of the keys read below.
    """
    checkpoint = load_cpu(path)
    model_type = checkpoint['args']['model_type']
    kwargs = checkpoint['execution_engine_kwargs']
    state = checkpoint['execution_engine_state']
    kwargs['verbose'] = verbose
    if model_type == 'FiLM':
        # NOTE(review): get_updated_args presumably reconciles the saved kwargs
        # with the current FiLMedNet constructor signature -- confirm.
        kwargs = get_updated_args(kwargs, FiLMedNet)
        model = FiLMedNet(**kwargs)
    elif model_type == 'EE':
        model = ModuleNet(**kwargs)
    elif model_type == 'MAC':
        # Supply defaults for MAC options that may be missing from the saved
        # kwargs (presumably older checkpoints), and discard a key that is
        # presumably not accepted by MAC's constructor -- confirm.
        kwargs.setdefault('write_unit', 'original')
        kwargs.setdefault('read_connect', 'last')
        kwargs.setdefault('noisy_controls', False)
        kwargs.pop('sharing_params_patterns', None)
        model = MAC(**kwargs)
    elif model_type == 'RelNet':
        model = RelationNet(**kwargs)
    elif model_type == 'SHNMN':
        model = SHNMN(**kwargs)
    else:
        raise ValueError(
            'Unsupported execution engine model_type: {!r}'.format(model_type))
    model.load_state_dict(state)
    return model, kwargs
Exemplo n.º 2
0
def load_execution_engine(path, verbose=True):
    """Rebuild an execution engine from a saved checkpoint.

    If the checkpoint's args set ``symbolic_ee``, a ``ClevrExecutor`` is built
    from the vocabulary file named in the args. Otherwise, a neural model is
    chosen by ``checkpoint['args']['model_type']``, constructed from the
    stored ``execution_engine_kwargs`` and loaded with the stored
    ``execution_engine_state``.

    Args:
        path: checkpoint location, passed unchanged to ``load_cpu``.
        verbose: stored into the constructor kwargs under ``'verbose'``
            (not used on the symbolic path).

    Returns:
        ``(engine, kwargs)``: the restored engine and its constructor kwargs.
        On the symbolic path, ``kwargs`` is an empty dict.

    Raises:
        ValueError: if ``model_type`` is not one of the supported types.
        KeyError: if the checkpoint lacks one of the keys read below.
    """
    checkpoint = load_cpu(path)
    if checkpoint['args'].get('symbolic_ee'):
        vocab = load_vocab(checkpoint['args']['vocab_json'])
        ee = ClevrExecutor(vocab)
        return ee, {}
    model_type = checkpoint['args']['model_type']
    kwargs = checkpoint['execution_engine_kwargs']
    state = checkpoint['execution_engine_state']
    kwargs['verbose'] = verbose
    if model_type == 'FiLM':
        model = FiLMedNet(**kwargs)
    elif model_type in ['PG+EE', 'EE', 'Control-EE']:
        # Drop a key ModuleNet presumably no longer accepts, and default the
        # options that may be absent from older saved kwargs -- confirm.
        kwargs.pop('sharing_patterns', None)
        kwargs.setdefault('module_pool', 'mean')
        kwargs.setdefault('module_use_gammas', 'linear')
        model = ModuleNet(**kwargs)
    elif model_type == 'MAC':
        # Same idea for MAC: default possibly-missing options, drop a key
        # presumably not accepted by MAC's constructor -- confirm.
        kwargs.setdefault('write_unit', 'original')
        kwargs.setdefault('read_connect', 'last')
        kwargs.setdefault('read_unit', 'original')
        kwargs.setdefault('noisy_controls', False)
        kwargs.pop('sharing_params_patterns', None)
        model = MAC(**kwargs)
    elif model_type == 'RelNet':
        model = RelationNet(**kwargs)
    elif model_type == 'SHNMN':
        model = SHNMN(**kwargs)
    elif model_type == 'SimpleNMN':
        model = SimpleModuleNet(**kwargs)
    else:
        raise ValueError(
            'Unsupported execution engine model_type: {!r}'.format(model_type))
    model.load_state_dict(state)
    return model, kwargs
Exemplo n.º 3
0
def load_execution_engine(path, verbose=True, model_type='PG+EE'):
    """Rebuild an execution-engine model from a saved checkpoint.

    Args:
        path: checkpoint location, passed to ``load_cpu``. It may be a ``str``
            or any object with a string form (e.g. ``pathlib.Path``); it is
            only formatted for the log line.
        verbose: stored into the constructor kwargs under ``'verbose'``.
        model_type: ``'FiLM'`` builds a ``FiLMedNet``. Any other value builds
            a ``ModuleNet``.

    Returns:
        ``(model, kwargs)``: the restored model and its constructor kwargs.

    Raises:
        KeyError: if the checkpoint lacks one of the keys read below.
    """
    checkpoint = load_cpu(path)
    kwargs = checkpoint['execution_engine_kwargs']
    state = checkpoint['execution_engine_state']
    kwargs['verbose'] = verbose
    if model_type == 'FiLM':
        # format() rather than '+' so non-str paths (e.g. Path) don't raise.
        print('Loading FiLMedNet from {}'.format(path))
        # NOTE(review): get_updated_args presumably reconciles the saved kwargs
        # with the current FiLMedNet signature -- confirm.
        kwargs = get_updated_args(kwargs, FiLMedNet)
        model = FiLMedNet(**kwargs)
    else:
        print('Loading EE from {}'.format(path))
        model = ModuleNet(**kwargs)
    model.load_state_dict(state)
    return model, kwargs
Exemplo n.º 4
0
def load_execution_engine(path, verbose=True, model_type="PG+EE"):
    """Rebuild an execution-engine model from a saved checkpoint.

    Args:
        path: checkpoint location, passed to ``load_cpu``. It may be a ``str``
            or any object with a string form (e.g. ``pathlib.Path``); it is
            only formatted for the log line.
        verbose: stored into the constructor kwargs under ``"verbose"``.
        model_type: ``"FiLM"`` builds a ``FiLMedNet``. Any other value builds
            a ``ModuleNet``.

    Returns:
        ``(model, kwargs)``: the restored model and its constructor kwargs.

    Raises:
        KeyError: if the checkpoint lacks one of the keys read below.
    """
    checkpoint = load_cpu(path)
    kwargs = checkpoint["execution_engine_kwargs"]
    state = checkpoint["execution_engine_state"]
    kwargs["verbose"] = verbose
    if model_type == "FiLM":
        # format() rather than '+' so non-str paths (e.g. Path) don't raise.
        print("Loading FiLMedNet from {}".format(path))
        # NOTE(review): get_updated_args presumably reconciles the saved kwargs
        # with the current FiLMedNet signature -- confirm.
        kwargs = get_updated_args(kwargs, FiLMedNet)
        model = FiLMedNet(**kwargs)
    else:
        print("Loading EE from {}".format(path))
        model = ModuleNet(**kwargs)
    model.load_state_dict(state)
    return model, kwargs
Exemplo n.º 5
0
def load_execution_engine(path, verbose=True, model_type='PG+EE'):
    """Rebuild an execution-engine model from a checkpoint with prefixed keys.

    Unlike a plain load, the first dot-separated component of every key in the
    saved state dict is removed before loading (e.g. ``'a.b.weight'`` becomes
    ``'b.weight'``).

    Args:
        path: checkpoint location, passed to ``load_cpu``. It may be a ``str``
            or any object with a string form (e.g. ``pathlib.Path``); it is
            only formatted for the log line.
        verbose: stored into the constructor kwargs under ``'verbose'``.
        model_type: ``'FiLM'`` builds a ``FiLMedNet``. Any other value builds
            a ``ModuleNet``.

    Returns:
        ``(model, kwargs)``: the restored model and its constructor kwargs.

    Raises:
        KeyError: if the checkpoint lacks one of the keys read below.
    """
    checkpoint = load_cpu(path)
    kwargs = checkpoint['execution_engine_kwargs']
    state = checkpoint['execution_engine_state']
    kwargs['verbose'] = verbose
    if model_type == 'FiLM':
        print('Loading FiLMedNet from {}'.format(path))
        # NOTE(review): get_updated_args presumably reconciles the saved kwargs
        # with the current FiLMedNet signature -- confirm.
        kwargs = get_updated_args(kwargs, FiLMedNet)
        model = FiLMedNet(**kwargs)
    else:
        print('Loading EE from {}'.format(path))
        model = ModuleNet(**kwargs)
    # Strip the leading key component (presumably a wrapper prefix such as
    # DataParallel's 'module.' -- confirm). A key with no dot maps to ''.
    # .items() works on both Python 2 and 3; .iteritems() is Python 2 only.
    state_stemmed = {'.'.join(k.split('.')[1:]): v for k, v in state.items()}
    model.load_state_dict(state_stemmed)
    return model, kwargs
Exemplo n.º 6
0
        'batch_size': args.batch_size,
        'max_samples': args.num_val_samples,
        'num_workers': args.loader_num_workers,
    }

    with ClevrDataLoader(**train_loader_kwargs) as train_loader, \
            ClevrDataLoader(**val_loader_kwargs) as val_loader:
        best_val_acc = 0.0
        for i in range(30):
            logger.info('Epoch ' + str(i))
            train_loss, train_acc = eval_epoch(train_loader,
                                               film_gen,
                                               filmed_net,
                                               opt=opt)
            valid_loss, valid_acc = eval_epoch(val_loader, film_gen,
                                               filmed_net)

            if train_loss.ndim == 1:
                train_loss = train_loss[0]
                valid_loss = valid_loss[0]
            logger.info("{}, {}, {}, {}".format(train_loss, train_acc,
                                                valid_loss, valid_acc))

            if valid_acc > best_val_acc:
                best_val_acc = valid_acc
                state = dict()
                state['film_gen'] = film_gen.state_dict()
                state['filmed_net'] = filmed_net.state_dict()
                state['args'] = args
                torch.save(state, os.path.join(exp_dir, 'model.pt'))