Example #1
def main(parser, name, load_valid_test_loader, load_model, eval_model):
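    # Restore a previous training run's metadata, config, and per-epoch state dicts, and re-run evaluation.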
    args = parser.parse_args()
    print(args)
    if args.use_wandb:
        wandb.init(project='refinenet-pytorch', name=name, config=args, dir='/home/user/research/refinenet-pytorch/train')

    run_path = 'yonyeoseok/refinenet-pytorch/{}'.format(args.restore_run_id)
    print(run_path)
    with wandb.restore('wandb-metadata.json', run_path=run_path, root='.') as f:
        metadata = json.load(f)
        assert metadata['name'] == 'training', metadata['name']
        os.remove('wandb-metadata.json')
    with wandb.restore('config.yaml', run_path=run_path, root='.') as f:
        config = yaml.load(f, Loader=yaml.BaseLoader)
        os.remove('config.yaml')

    valid_dl, test_dl = load_valid_test_loader(args)
    # valid_dl.dataset.indices = valid_dl.dataset.indices[:18]
    # test_dl.dataset.images = test_dl.dataset.images[:3]
    print('dataset loaded')
    model = load_model(args)
    print('model loaded')
    wandb_log = WandbLog(args.use_wandb)

    for epoch in range(int(config['total_epoch']['value'])+1):
        wandb_log.running_metrics_epoch_step = epoch
        state_dict_path = 'state_dict.{:02d}.pth'.format(epoch)
        with wandb.restore(state_dict_path, run_path=run_path) as f:
            state_dict = torch.load(f.name)
        model.load_state_dict(state_dict)
        eval_model(model, valid_dl, test_dl, wandb_log, args)
        if not args.use_wandb:
            os.remove(state_dict_path)
Example #2
    def _restore_wandb_checkpoint(self) -> None:
        with wandb.restore('checkpoint_latest') as f:
            ckpt_file = f.read().strip()
        assert ckpt_file, "Can't resume wandb run: no checkpoint found!"
        wandb.restore(ckpt_file)
        wandb.restore('checkpoint_additional_data.pickle')
        self.checkpoint = Path(wandb.run.dir) / ckpt_file
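
The method above implies a matching save side that records the latest checkpoint's file name in a small pointer file. A minimal sketch of that writer, assuming an active run whose training loop names checkpoint files itself (the file names mirror the restore side; the function name is illustrative):

import pickle
from pathlib import Path

import wandb

def save_wandb_checkpoint(ckpt_file: str, extra_data: dict) -> None:
    run_dir = Path(wandb.run.dir)
    # Pointer file read back later by wandb.restore('checkpoint_latest').
    (run_dir / 'checkpoint_latest').write_text(ckpt_file)
    with open(run_dir / 'checkpoint_additional_data.pickle', 'wb') as f:
        pickle.dump(extra_data, f)
    # Files under wandb.run.dir sync when the run finishes; wandb.save
    # uploads them immediately instead.
    wandb.save(str(run_dir / 'checkpoint_latest'), base_path=str(run_dir))
    wandb.save(str(run_dir / 'checkpoint_additional_data.pickle'),
               base_path=str(run_dir))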
Example #3
def restore(opt):
    if not opt.restore:
        return None

    print('-' * 30)
    print(f'Restoring from {opt.wandb_runpath}/{opt.restore}...')
    try:
        if '/home' in opt.wandb_runpath:
            # Restore from local directory
            from shutil import copyfile
            copyfile(
                opt.wandb_runpath,
                os.path.join(wandb.run.dir, f'checkpoint_{opt.restore}.pth'))
        else:
            # Copy from a previous run to the current run directory
            wandb.restore(f'checkpoint_{opt.restore}.pth',
                          run_path=opt.wandb_runpath)

        # Load the checkpoint
        checkpoint = torch.load(
            os.path.join(wandb.run.dir, f'checkpoint_{opt.restore}.pth'))
        return checkpoint
    except Exception as e:
        print('Restoring failed :(', e)
        return None
Example #4
def main(parser, name, load_valid_test_loader, load_feature_regression_model,
         eval_model):
    args = parser.parse_args()
    print(args)
    if args.use_wandb:
        wandb.init(project='refinenet-pytorch',
                   name=name,
                   config=args,
                   dir='/home/user/research/refinenet-pytorch/train')

    run_path = 'yonyeoseok/refinenet-pytorch/{}'.format(args.restore_run_id)
    print(run_path)
    with wandb.restore('wandb-metadata.json', run_path=run_path,
                       root='.') as f:
        metadata = json.load(f)
        assert metadata['name'] == 'feature-regression-training', metadata[
            'name']
        os.remove('wandb-metadata.json')
    with wandb.restore('config.yaml', run_path=run_path, root='.') as f:
        config = yaml.load(f, Loader=yaml.BaseLoader)
        args.feature_layer = config['feature_layer']
        os.remove('config.yaml')

    valid_dl, test_dl = load_valid_test_loader(args)
    # valid_dl.dataset.indices = valid_dl.dataset.indices[:2]
    # test_dl.dataset.images = test_dl.dataset.images[:2]
    print('dataset loaded')
    feature_regression_model = load_feature_regression_model(args)
    print('model loaded')
    wandb_log = WandbLog(args.use_wandb)

    for epoch in range(int(config['total_epoch']['value']) + 1):
        wandb_log.running_metrics_epoch_step = epoch
        state_dict_path = 'state_dict.{:02d}.pth'.format(epoch)
        if args.use_wandb:
            for local_run_dir in os.listdir(os.path.dirname(wandb.run.dir)):
                if args.restore_run_id in local_run_dir:
                    if state_dict_path in os.listdir(
                            os.path.join(os.path.dirname(wandb.run.dir),
                                         local_run_dir)):
                        source_file = os.path.join(
                            os.path.dirname(wandb.run.dir), local_run_dir,
                            state_dict_path)
                        target_file = os.path.join(wandb.run.dir,
                                                   state_dict_path)
                        shutil.copyfile(source_file, target_file)
                    break
        with wandb.restore(state_dict_path, run_path=run_path) as f:
            state_dict = torch.load(f.name)
        wandb.save(state_dict_path)
        feature_regression_model.load_state_dict(state_dict)
        eval_model(feature_regression_model.model, valid_dl, test_dl,
                   wandb_log, args)
        if not args.use_wandb:
            os.remove(state_dict_path)
Example #5
def get_coach_path(coach, coach_variant=None):
    if "cloned" in coach:
        coach_path = coach["best_coach"]
    else:
        coach_str = ("best_coach" if coach_variant is None else
                     f"best_coach_{coach_variant}")
        coach_path = wandb.restore(coach[coach_str],
                                   run_path=coach["run_path"]).name
        wandb.restore(coach[coach_str] + ".params", run_path=coach["run_path"])

    return coach_path
Example #6
def load_model(run, G, D):
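    # Pull generator and discriminator weights saved to the current wandb run.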
    file_G = wandb.restore("G.pth").name
    file_D = wandb.restore("D.pth").name

    if not torch.cuda.is_available():
        G.load_state_dict(torch.load(file_G, map_location={"cuda:0": "cpu"}))
        D.load_state_dict(torch.load(file_D, map_location={"cuda:0": "cpu"}))
    else:
        G.load_state_dict(torch.load(file_G))
        D.load_state_dict(torch.load(file_D))

    return G, D
Example #7
def get_executor_path(executor, exec_variant=None):
    if "cloned" in executor:
        exec_path = executor["best_exec"]
    else:
        exec_str = "best_exec" if exec_variant is None else f"best_exec_{exec_variant}"
        exec_path = wandb.restore(executor[exec_str],
                                  run_path=executor["run_path"]).name
        wandb.restore(
            executor[exec_str] + ".params",
            run_path=executor["run_path"],
        )

    return exec_path
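
Examples #5 and #7 are the same lookup specialized to two model types. A generic sketch of the shared pattern, with illustrative parameter names not taken from either source:

import wandb

def get_checkpoint_path(spec, key, variant=None):
    # A 'cloned' spec already points at a local file.
    if "cloned" in spec:
        return spec[key]
    name = key if variant is None else f"{key}_{variant}"
    # Download the weights plus their companion .params file from the recorded run.
    path = wandb.restore(spec[name], run_path=spec["run_path"]).name
    wandb.restore(spec[name] + ".params", run_path=spec["run_path"])
    return path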
Example #8
    def load_network(self, loaded_net=None):
        add_log = False
        if loaded_net is None:
            add_log = True
            if self.cfg.load.wandb_load_path is not None:
                self.cfg.load.network_chkpt_path = wandb.restore(
                    self.cfg.load.network_chkpt_path,
                    run_path=self.cfg.load.wandb_load_path,
                ).name
            loaded_net = torch.load(
                self.cfg.load.network_chkpt_path,
                map_location=torch.device(self.device),
            )
        loaded_clean_net = OrderedDict()  # remove unnecessary 'module.'
        for k, v in loaded_net.items():
            if k.startswith("module."):
                loaded_clean_net[k[7:]] = v
            else:
                loaded_clean_net[k] = v

        self.net.load_state_dict(loaded_clean_net,
                                 strict=self.cfg.load.strict_load)
        if is_logging_process() and add_log:
            self._logger.info("Checkpoint %s is loaded" %
                              self.cfg.load.network_chkpt_path)
Example #9
def test_restore(runner, wandb_init_run, request_mocker, download_url, query_run_v2, query_run_files):
    with runner.isolated_filesystem():
        query_run_v2(request_mocker)
        query_run_files(request_mocker)
        download_url(request_mocker, size=10000)
        res = wandb.restore("weights.h5")
        assert os.path.getsize(res.name) == 10000
Example #10
def __download_dataset(env_name: str, pre_trained: int = 1, no_cache=False):
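    # Download the HDF5 dataset for a pre-trained policy once, caching it under dataset_root.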
    env_info = __get_env_info(env_name)
    run_path = env_info['wandb_run_path']
    dataset_name = 'dataset_{}.h5'.format(env_info['models'][pre_trained])
    dataset_root = os.path.join(env_name, POLICY_BAZAAR_DIR, env_name,
                                'pre_trained_{}'.format(pre_trained),
                                'dataset')
    os.makedirs(dataset_root, exist_ok=True)
    dataset_path = os.path.join(dataset_root, dataset_name)
    if not (os.path.exists(dataset_path)) or no_cache:
        wandb.restore(name=dataset_name,
                      run_path=run_path,
                      replace=True,
                      root=dataset_root)

    return dataset_path
Example #11
def get_most_recent_model_file(
        wandb_run: WandbRun,
        wandb_ckpt_path='checkpoints/',
        model_name='',
        exclude=None,
        step_extractor=lambda fname: fname.split("_")[1].split(".")[0]):
    # Find checkpoints
    checkpoints = [
        file for file in wandb_run.files
        if file.name.startswith(wandb_ckpt_path.lstrip("/"))
    ]
    relevant_checkpoints = [e for e in checkpoints if model_name in e.name]
    if exclude:
        relevant_checkpoints = [
            e for e in relevant_checkpoints if exclude not in e.name
        ]
    # Grab the latest checkpoint
    latest_checkpoint = relevant_checkpoints[np.argmax(
        [int(step_extractor(e.name)) for e in relevant_checkpoints])]
    print(f"Retrieved checkpoint {latest_checkpoint.name}.")
    # Restore the model
    model_file = wandb.restore(latest_checkpoint.name,
                               run_path=wandb_run.path,
                               replace=True)

    return model_file
Example #12
    def generic_builder(self, name, net, lr=1e-2, dropout_rate=0.2):
        cfg = self.cfg
        inputs = layers.Input(shape=cfg['img_shape'])
        x = img_augmentation(inputs)
        if cfg['transfer_learning'] and not cfg['dynamic_reorg_enabled']:
            model = net(include_top=False, input_tensor=x, weights='imagenet')

            # Freeze the pretrained weights
            model.trainable = False

            # Rebuild top
            x = layers.GlobalAveragePooling2D(name="avg_pool")(model.output)
            x = layers.BatchNormalization()(x)
            top_dropout_rate = dropout_rate
            x = layers.Dropout(top_dropout_rate, name="top_dropout")(x)
            outputs = layers.Dense(cfg['num_classes'],
                                   activation="softmax",
                                   name="pred")(x)
            model = tf.keras.Model(inputs, outputs, name=name)
        elif cfg['dynamic_reorg_enabled']:
            model_path = cfg['dynamic_reorg_model_path']
            if cfg['dynamic_reorg_wandb_run_name'] is not None:
                wandb_model = wandb.restore(
                    'model-best.h5', run_path="kzawora/lego-4h/runs/66z3b5tl")
                model_path = wandb_model.name
            model = tf.keras.models.load_model(model_path)
            model.trainable = False
            dense_input_tensor = model.output if cfg[
                'dynamic_reorg_preserve_top_layer'] else model.layers[-2].output
            outputs = layers.Dense(cfg['num_classes'],
                                   activation="softmax",
                                   name="pred_reorg")(dense_input_tensor)
            model = tf.keras.Model(model.input, outputs)
        else:
            model = net(include_top=True,
                        input_tensor=x,
                        weights=None,
                        classes=cfg['num_classes'])
            # model.trainable = True

            # Rebuild top
            # x = layers.GlobalAveragePooling2D(name="avg_pool")(model.output)
            # x = layers.BatchNormalization()(x)
            # top_dropout_rate = dropout_rate
            # x = layers.Dropout(top_dropout_rate, name="top_dropout")(model.output)
            # outputs = layers.Dense(
            #     cfg['num_classes'], activation="softmax", name="pred")(x)

        # Compile
        optimizer = self.get_optimizer(learning_rate=lr)
        model.compile(optimizer=optimizer,
                      loss="categorical_crossentropy",
                      metrics=["accuracy", "top_k_categorical_accuracy"])
        return model
Example #13
def restore_files(run_path, folder):
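    # Download the full set of GAN checkpoint files from a previous run and map each to its local path.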
    names = ['G.pth', 'G_opt.pth', 'D.pth',
             'D_opt.pth', 'Gs.pth', 'kwargs.json']
    result = {}
    for file_name in names:
        print(f'Downloading {file_name}')
        weight = wandb.restore(f'{folder}/{file_name}', run_path=run_path)
        key = file_name.split('.')[0]
        result[key] = weight.name
    return result
Example #14
def get_state_dict_and_config(wandb_run_path,
                              config_file_name: str = 'exp_config.yaml',
                              model_file_name: str = 'model.pt'):
    ret = dict()
    if config_file_name is not None:
        model_config_path = wandb.restore(config_file_name,
                                          wandb_run_path,
                                          replace=True)
        config = Box.from_yaml(filename=model_config_path.name)
        ret['config'] = config

    if model_file_name is not None:
        model_path = wandb.restore(model_file_name,
                                   wandb_run_path,
                                   replace=True)
        model_state_dict = torch.load(model_path.name, map_location='cpu')
        ret['state_dict'] = model_state_dict

    return ret
Example #15
def get_model_weights():
    """
    Retrieves the trained CNN (Mask classifier) weights from Weights and Biases

    Return:
        Returns the path to the best CNN model weights
    """
    best_model = wandb.restore("model-best.h5",
                               run_path="seedatnabeel/mask_cv_model/34y9teh1")
    return best_model.name
Example #16
def load_model(idx):
    chpt = os.path.join(PROJECT, idx, 'checkpoints/last.ckpt')

    # load model from wandb
    best_model = wandb.restore(chpt,
                               run_path=os.path.join(USER, PROJECT, idx),
                               replace=True,
                               root='wandb')
    # load the model from the saved checkpoint
    vae = StandardVAE.load_from_checkpoint(checkpoint_path=best_model.name)
    return vae
Example #17
def get_trained_lang_embedding():
    """
    Retrieves the trained language embedding from Weights and Biases for the WHO - FAQS

    Return:
        Returns a numpy array of the language embedding
    """
    dill_file = wandb.restore("lang_embeddings.p",
                              run_path="seedatnabeel/qa_lang_model/g1fo1le9")
    # Load the pickled embeddings
    with open(dill_file.name, "rb") as f:
        embed = dill.load(f)
    return embed
Example #18
    def load_network(self, load_path, network, strict=True, wandb_load_run_path=None):
        if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
            network = network.module
        if wandb_load_run_path is not None:
            load_path = wandb.restore(load_path, run_path=wandb_load_run_path).name
        load_net = torch.load(load_path)
        load_net_clean = OrderedDict()  # remove unnecessary 'module.'
        for k, v in load_net.items():
            if k.startswith('module.'):
                load_net_clean[k[7:]] = v
            else:
                load_net_clean[k] = v
        network.load_state_dict(load_net_clean, strict=strict)
Example #19
    def load_training_state(self, logger):
        if self.hp.log.use_wandb and self.hp.load.wandb_load_path is not None:
            self.hp.load.resume_state_path = wandb.restore(
                self.hp.load.resume_state_path,
                run_path=self.hp.load.wandb_load_path).name
        resume_state = torch.load(self.hp.load.resume_state_path)

        self.load_network(loaded_clean_net=resume_state["model"],
                          logger=logger)
        self.optimizer.load_state_dict(resume_state["optimizer"])
        self.step = resume_state["step"]
        self.epoch = resume_state["epoch"]
        logger.info("Resuming from training state: %s" %
                    self.hp.load.resume_state_path)
Example #20
def restore_model(file, storage='local', encoding='utf-8'):
    if storage == 'wandb':
        parts = file.split('/')
        wandb_path = '/'.join(parts[:-1])
        wandb_file = parts[-1]
        restore_file = wandb.restore(wandb_file, run_path=wandb_path)
        checkpoint = torch.load(restore_file.name, encoding=encoding)
    elif storage == 'local':  # local storage
        checkpoint = torch.load(file, encoding=encoding)
    else:
        print('Unknown storage type')
        checkpoint = None

    return checkpoint
Example #21
def load_wandb_checkpoint(entity, project, run_id):
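    # Walk backwards from the last epoch in the run summary until a checkpoint file can be restored.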
    api = wandb.Api()
    run = api.run("{}/{}/{}".format(entity, project, run_id))
    epoch = run.summary.epoch
    loaded = False
    while not loaded:
        if epoch < 0:
            raise Exception("No saved checkpoints.")
        try:
            checkpoint_path = wandb.restore(
                'checkpoints/checkpoint_{}.pt'.format(epoch),
                run_path="{}/{}/{}".format(entity, project, run_id)).name
            loaded = True
        except Exception:
            epoch -= 1
    trainer = load_checkpoint(checkpoint_path)
    return trainer
Example #22
    def load_network(self, loaded_clean_net=None, logger=None):
        if loaded_clean_net is None:
            if self.hp.log.use_wandb and self.hp.load.wandb_load_path is not None:
                self.hp.load.network_chkpt_path = wandb.restore(
                    self.hp.load.network_chkpt_path,
                    run_path=self.hp.load.wandb_load_path,
                ).name
            loaded_net = torch.load(self.hp.load.network_chkpt_path)
            loaded_clean_net = OrderedDict()  # remove unnecessary 'module.'
            for k, v in loaded_net.items():
                if k.startswith("module."):
                    loaded_clean_net[k[7:]] = v
                else:
                    loaded_clean_net[k] = v

        self.net.load_state_dict(loaded_clean_net, strict=self.hp.load.strict_load)
        if logger is not None:
            logger.info("Checkpoint %s is loaded" % self.hp.load.network_chkpt_path)
Example #23
    def load_training_state(self):
        if self.cfg.load.wandb_load_path is not None:
            self.cfg.load.resume_state_path = wandb.restore(
                self.cfg.load.resume_state_path,
                run_path=self.cfg.load.wandb_load_path,
            ).name
        resume_state = torch.load(
            self.cfg.load.resume_state_path,
            map_location=torch.device(self.device),
        )

        self.load_network(loaded_net=resume_state["model"])
        self.optimizer.load_state_dict(resume_state["optimizer"])
        self.step = resume_state["step"]
        self.epoch = resume_state["epoch"]
        if is_logging_process():
            self._logger.info("Resuming from training state: %s" %
                              self.cfg.load.resume_state_path)
Example #24
def main():
    args = parse_args()
    assert args.output.endswith(".pth")

    checkpoint = args.checkpoint
    if path.exists(checkpoint):
        ck = torch.load(checkpoint, map_location=torch.device('cpu'))
    else:
        # if checkpoint is not a valid path, try restoring it from a wandb run
        restored_model = wandb.restore('latest.pth', run_path=f"{checkpoint}", replace=False)
        if restored_model is None:
            raise Exception(f"failed to load the model from run id or path: {checkpoint}")
        ck = torch.load(restored_model.name, map_location=torch.device('cpu'))

    output_dict = dict(state_dict=dict(), author="OpenSelfSup")
    has_backbone = False
    for key, value in ck['state_dict'].items():
        if key.startswith('backbone'):
            output_dict['state_dict'][key[9:]] = value
            has_backbone = True
    if not has_backbone:
        raise Exception("Cannot find a backbone module in the checkpoint.")
    torch.save(output_dict, args.output)
Example #25
def main(argv):

    argdict = {
        flag.name: flag.value
        for flag in FLAGS.flags_by_module_dict()['args']
    }
    pprint(argdict)
    argdict['tags'] = argdict['tags'].split(',')
    wandb.init(project='extreme-classification',
               name=argdict['name'],
               tags=argdict['tags'])
    wandb.config.update(argdict)

    device = get_device(FLAGS.cuda, FLAGS.cuda_device)

    if FLAGS.mode == 'train':
        if FLAGS.preprocess_mode == 'synth':
            train_data, dev_data, test_data = synth_gen(
                FLAGS.synth_mode, FLAGS)
            class_map = {
                idx: idx
                for idx in set(train_data.labels + dev_data.labels)
            }

        else:
            preprocessor = Preprocessor(FLAGS.preprocess_mode)
            train_data = preprocessor.preprocess_file(FLAGS.train_path)
            dev_data = preprocessor.preprocess_file(FLAGS.dev_path)

            class_map = {
                label: idx
                for idx, label in enumerate(
                    sorted(set(train_data.labels + dev_data.labels)))
            }
        model = VecModel(train_data[0][0].shape[0],
                         class_map,
                         None,
                         enable_binary=FLAGS.enable_binary)
        model.init_norms(train_data)

        trainer = Trainer(FLAGS)
        train_tmp_dir = tempfile.TemporaryDirectory()
        os.chdir(train_tmp_dir.name)
        trainer.train(train_data,
                      dev_data,
                      model,
                      device=device,
                      train_tmp_dir='.')

        if FLAGS.model_metric is not None:
            model.load_state_dict(
                torch.load(f'model_{FLAGS.model_metric}.state'))
        if FLAGS.preprocess_mode == 'synth':
            run_test(model, test_data, device)
        elif FLAGS.test_path is not None:
            test_data = preprocessor.preprocess_file(FLAGS.test_path)
            run_test(model, test_data, device)

    else:  # test mode
        preprocessor = Preprocessor(FLAGS.preprocess_mode)
        test_data = preprocessor.preprocess_file(FLAGS.test_path)
        test_tmp_dir = tempfile.TemporaryDirectory()
        wandb.restore('model.pkl',
                      run_path=FLAGS.model_run_path,
                      root=test_tmp_dir.name)
        with open(os.path.join(test_tmp_dir.name, 'model.pkl'), 'rb') as f:
            params = pickle.load(f)
        model = VecModel(*params)
        state_name = f'model_{FLAGS.model_metric}.state'
        wandb.restore(state_name,
                      run_path=FLAGS.model_run_path,
                      root=test_tmp_dir.name)
        model.load_state_dict(
            torch.load(os.path.join(test_tmp_dir.name, state_name)))
        run_test(model, test_data, device)
Example #26
X_train = X_train.astype('float32')
X_train /= 255.
X_test = X_test.astype('float32')
X_test /= 255.

X_train = X_train.reshape(X_train.shape[0], img_width, img_height, 1)
X_test = X_test.reshape(X_test.shape[0], img_width, img_height, 1)

y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
num_classes = y_test.shape[1]

# Check whether the run was resumed; if so, load the best model
if wandb.run.resumed:
    print("Resuming model with config: {}".format(dict(config)))
    model = load_model(wandb.restore("model-best.h5").name)
else:
    sgd = SGD(lr=config.learn_rate,
              decay=config.decay,
              momentum=config.momentum,
              nesterov=True)
    model = Sequential()
    model.add(
        Conv2D(config.layer_1_size, (5, 5),
               activation='relu',
               input_shape=(img_width, img_height, 1)))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(config.layer_2_size, (5, 5), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(config.dropout))
    model.add(Flatten())
Example #27
# %%
from google.colab import auth
auth.authenticate_user()

PROJECT_ID = "fast-ai-exploration"
get_ipython().system('gcloud config set project $PROJECT_ID')

# %%
get_ipython().system(
    'gsutil cp gs://resnet_simclr_imagenet/20200508-134915resnet_simclr.h5 .')

# %% [markdown]
# ### Restoring model weights from `wandb` run page

# %%
simclr_weights = wandb.restore("20200509-042927resnet_simclr.h5",
                               run_path="sayakpaul/simclr/simclr-learning")

# %%
# Other imports
from sklearn.preprocessing import LabelEncoder
from sklearn.manifold import TSNE
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from wandb.keras import WandbCallback
import matplotlib.pyplot as plt
from imutils import paths
from tqdm import tqdm
import tensorflow as tf
import seaborn as sns
import numpy as np
import cv2
Example #28
def test_restore_name_not_found(runner, mock_server, wandb_init_run):
    with runner.isolated_filesystem():
        with pytest.raises(ValueError):
            wandb.restore("nofile.h5")
Example #29
def test_restore(runner, mock_server, wandb_init_run):
    with runner.isolated_filesystem():
        mock_server.set_context("files", {"weights.h5": 10000})
        res = wandb.restore("weights.h5")
        assert os.path.getsize(res.name) == 10000
Example #30
def test_restore_no_init(runner, mock_server):
    with runner.isolated_filesystem():
        mock_server.set_context("files", {"weights.h5": 10000})
        res = wandb.restore("weights.h5", run_path="foo/bar/baz")
        assert os.path.getsize(res.name) == 10000
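
Every example above reduces to the same round trip: one run saves a file with wandb.save, and a later process downloads it with wandb.restore and reads it from the returned handle's .name path. A minimal sketch with placeholder entity, project, and run id:

import torch
import wandb

# run_path is '<entity>/<project>/<run_id>'; the values below are placeholders.
restored = wandb.restore("weights.pth", run_path="my-entity/my-project/abc123")
state_dict = torch.load(restored.name, map_location="cpu")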