def main():
    assert tf.executing_eagerly(), 'Only eager mode is supported.'
    assert args.config is not None, 'You need to pass in model config file path'
    assert args.data is not None, 'You need to pass in episode config file path'
    assert args.env is not None, 'You need to pass in environ config file path'
    assert args.tag is not None, 'You need to specify a tag'

    log.info('Command line args {}'.format(args))
    config = get_config(args.config, ExperimentConfig)
    data_config = get_config(args.data, EpisodeConfig)
    env_config = get_config(args.env, EnvironmentConfig)
    log.info('Model: \n{}'.format(config))
    log.info('Data episode: \n{}'.format(data_config))
    log.info('Environment: \n{}'.format(env_config))
    config.num_classes = data_config.nway  # Assign num classes.
    config.num_steps = data_config.maxlen
    config.memory_net_config.max_classes = data_config.nway
    config.memory_net_config.max_stages = data_config.nstage
    config.memory_net_config.max_items = data_config.maxlen
    config.oml_config.num_classes = data_config.nway
    config.fix_unknown = data_config.fix_unknown  # Assign fix unknown ID.
    log.info('Number of classes {}'.format(data_config.nway))
    if 'SLURM_JOB_ID' in os.environ:
        log.info('SLURM job ID: {}'.format(os.environ['SLURM_JOB_ID']))

    # Create save folder.
    save_folder = os.path.join(env_config.results, env_config.dataset,
                               args.tag)
    results_file = os.path.join(save_folder, 'results.pkl')
    logfile = os.path.join(save_folder, 'results.csv')

    # To get CNN features from.
    model = build_pretrain_net(config)
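    # Either wrap the backbone with an RNN memory readout or use raw backbone features.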
    if args.rnn:
        rnn_memory = get_module("lstm",
                                "lstm",
                                model.backbone.get_output_dimension()[0],
                                config.lstm_config.hidden_dim,
                                layernorm=config.lstm_config.layernorm,
                                dtype=tf.float32)
        memory = RNNEncoder(
            "proto_plus_rnn_ssl_v4",
            rnn_memory,
            readout_type=config.mann_config.readout_type,
            use_pred_beta_gamma=config.hybrid_config.use_pred_beta_gamma,
            use_feature_fuse=config.hybrid_config.use_feature_fuse,
            use_feature_fuse_gate=config.hybrid_config.use_feature_fuse_gate,
            use_feature_scaling=config.hybrid_config.use_feature_scaling,
            use_feature_memory_only=config.hybrid_config.use_feature_memory_only,
            skip_unk_memory_update=config.hybrid_config.skip_unk_memory_update,
            use_ssl=config.hybrid_config.use_ssl,
            use_ssl_beta_gamma_write=config.hybrid_config.use_ssl_beta_gamma_write,
            use_ssl_temp=config.hybrid_config.use_ssl_temp,
            dtype=tf.float32)
        rnn_model = RNNEncoderNet(config, model.backbone, memory)
        f = lambda x: rnn_model.forward(x, is_training=False)  # NOQA
    else:
        f = lambda x: model.backbone(x[0], is_training=False)[None, :, :]  # NOQA

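    # Number of classes (ways) per episode.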
    K = config.num_classes
    # Get dataset.
    dataset = get_data_fs(env_config, load_train=True)

    # Get data iterators.
    if env_config.dataset in ["matterport"]:
        data = get_dataiter_sim(
            dataset,
            data_config,
            batch_size=config.optimizer_config.batch_size,
            nchw=model.backbone.config.data_format == 'NCHW',
            seed=args.seed)
    else:
        data = get_dataiter_continual(
            dataset,
            data_config,
            batch_size=config.optimizer_config.batch_size,
            nchw=model.backbone.config.data_format == 'NCHW',
            save_additional_info=True,
            random_box=data_config.random_box,
            seed=args.seed)
    if os.path.exists(results_file) and args.reeval:
        # Re-display results.
        results_all = pkl.load(open(results_file, 'rb'))
        for split, name in zip(['trainval_fs', 'val_fs', 'test_fs'],
                               ['Train', 'Val', 'Test']):
            stats = get_stats(results_all[split], tmax=data_config.maxlen)
            log_results(stats, prefix=name, filename=logfile)
    else:
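        # Load the latest checkpoint and evaluate all splits from scratch.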
        latest = latest_file(save_folder, 'weights-')

        if args.rnn:
            rnn_model.load(latest)
        else:
            model.load(latest)

        data['trainval_fs'].reset()
        data['val_fs'].reset()
        data['test_fs'].reset()

        results_all = {}
        if args.testonly:
            split_list = ['test_fs']
            name_list = ['Test']
            nepisode_list = [config.num_episodes]
        else:
            split_list = ['trainval_fs', 'val_fs', 'test_fs']
            name_list = ['Train', 'Val', 'Test']
            nepisode_list = [600, config.num_episodes, config.num_episodes]

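        # Evaluate each split, log aggregate statistics, and collect raw results.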
        for split, name, N in zip(split_list, name_list, nepisode_list):
            data[split].reset()
            r1 = evaluate(f, K, data[split], N)
            stats = get_stats(r1, tmax=data_config.maxlen)
            log_results(stats, prefix=name, filename=logfile)
            results_all[split] = r1
        pkl.dump(results_all, open(results_file, 'wb'))
Example #2
def train(model,
          dataiter,
          dataiter_traintest,
          dataiter_test,
          ckpt_folder,
          final_save_folder=None,
          nshot_max=5,
          maxlen=40,
          logger=None,
          writer=None,
          in_stage=False,
          is_chief=True,
          reload_flag=None):
    """Trains the online few-shot model.
  Args:
    model: Model instance.
    dataiter: Dataset iterator.
    dataiter_test: Dataset iterator for validation.
    save_folder: Path to save the checkpoints.
  """
    N = model.max_train_steps
    config = model.config.train_config

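    # Logging helpers that become no-ops when no logger is attached.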
    def try_log(*args, **kwargs):
        if logger is not None:
            logger.log(*args, **kwargs)

    def try_flush(*args, **kwargs):
        if logger is not None:
            logger.flush()

    r = None
    keep = True
    restart = False
    best_val = 0.0
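    # The outer loop restarts training from the latest checkpoint if the loss diverges.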
    while keep:
        keep = False
        start = model.step.numpy()
        if start > 0:
            log.info('Restore from step {}'.format(start))

        it = six.moves.xrange(start, N)
        if is_chief:
            it = tqdm(it, ncols=0)
        for i, batch in zip(it, dataiter):
            tf.summary.experimental.set_step(i + 1)
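            # Episode inputs and labels for this training step.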
            x = batch['x_s']
            y = batch['y_s']
            y_gt = batch['y_gt']

            kwargs = {'y_gt': y_gt, 'flag': batch['flag_s']}
            kwargs['writer'] = writer
            loss = model.train_step(x, y, **kwargs)

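            # On the first step of a fresh run (not a divergence restart), reload
            # weights and optimizer state from the given checkpoint.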
            if i == start and reload_flag is not None and not restart:
                model.load(reload_flag, load_optimizer=True)

            # Synchronize distributed weights.
            if i == start and model._distributed and not restart:
                import horovod.tensorflow as hvd
                hvd.broadcast_variables(model.var_to_optimize(), root_rank=0)
                hvd.broadcast_variables(model.optimizer.variables(),
                                        root_rank=0)
                if model.config.set_backbone_lr:
                    hvd.broadcast_variables(model._bb_optimizer.variables(),
                                            root_rank=0)

            if loss > 10.0 and i > start + config.steps_per_save:
                # Something wrong happened.
                log.error('Something wrong happened. loss = {}'.format(loss))
                import pickle as pkl
                pkl.dump(
                    batch,
                    open(os.path.join(final_save_folder, 'debug.pkl'), 'wb'))
                log.error('debug file dumped')
                restart = True
                keep = True

                latest = latest_file(ckpt_folder, 'weights-')
                model.load(latest)
                log.error('Reloaded latest checkpoint from {}'.format(latest))
                break

            # Evaluate.
            if is_chief and ((i + 1) % config.steps_per_val == 0 or i == 0):
                for key, data_it_ in zip(['train', 'val'],
                                         [dataiter_traintest, dataiter_test]):
                    data_it_.reset()
                    r1 = evaluate(model, data_it_, 120)
                    r = get_stats(r1, nshot_max=nshot_max, tmax=maxlen)
                    for s in range(nshot_max):
                        try_log('online fs acc {}/s{}'.format(key, s), i + 1,
                                r['acc_nshot'][s] * 100.0)
                    try_log('online fs ap {}'.format(key), i + 1,
                            r['ap'] * 100.0)
                try_log('lr', i + 1, model.learn_rate())
                print()

            # Save.
            if is_chief and ((i + 1) % config.steps_per_save == 0 or i == 0):
                model.save(
                    os.path.join(ckpt_folder, 'weights-{}'.format(i + 1)))

                # Save the best checkpoint.
                if r is not None:
                    if r['ap'] > best_val:
                        model.save(
                            os.path.join(final_save_folder,
                                         'best-{}'.format(i + 1)))
                        best_val = r['ap']

            # Write logs.
            if is_chief and ((i + 1) % config.steps_per_log == 0 or i == 0):
                try_log('loss', i + 1, loss)
                try_flush()

                # Update progress bar.
                post_fix_dict = {}
                post_fix_dict['lr'] = '{:.3e}'.format(model.learn_rate())
                post_fix_dict['loss'] = '{:.3e}'.format(loss)
                if r is not None:
                    post_fix_dict['ap_val'] = '{:.3f}'.format(r['ap'] * 100.0)
                it.set_postfix(**post_fix_dict)

    # Save.
    if is_chief and final_save_folder is not None:
        model.save(os.path.join(final_save_folder, 'weights-{}'.format(N)))
Example #3
def main():
    assert tf.executing_eagerly(), 'Only eager mode is supported.'
    assert args.config is not None, 'You need to pass in model config file path'
    assert args.data is not None, 'You need to pass in episode config file path'
    assert args.env is not None, 'You need to pass in environ config file path'
    assert args.tag is not None, 'You need to specify a tag'

    log.info('Command line args {}'.format(args))
    config = get_config(args.config, ExperimentConfig)
    data_config = get_config(args.data, EpisodeConfig)
    env_config = get_config(args.env, EnvironmentConfig)
    log.info('Model: \n{}'.format(config))
    log.info('Data episode: \n{}'.format(data_config))
    log.info('Environment: \n{}'.format(env_config))
    config.num_classes = data_config.nway  # Assign num classes.
    config.num_steps = data_config.maxlen
    config.memory_net_config.max_classes = data_config.nway
    config.memory_net_config.max_stages = data_config.nstage
    config.memory_net_config.max_items = data_config.maxlen
    config.oml_config.num_classes = data_config.nway
    config.fix_unknown = data_config.fix_unknown  # Assign fix unknown ID.
    log.info('Number of classes {}'.format(data_config.nway))

    if 'SLURM_JOB_ID' in os.environ:
        log.info('SLURM job ID: {}'.format(os.environ['SLURM_JOB_ID']))

    # Create save folder.
    save_folder = os.path.join(env_config.results, env_config.dataset,
                               args.tag)

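    # Build the model and data iterators unless we only re-display saved results.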
    if not args.reeval:
        model = build_pretrain_net(config)
        mem_model = build_net(config, backbone=model.backbone)
        reload_flag = None
        restore_steps = 0
        # Checkpoint folder.
        ckpt_path = env_config.checkpoint
        if len(ckpt_path) > 0 and os.path.exists(ckpt_path):
            ckpt_folder = os.path.join(ckpt_path, os.environ['SLURM_JOB_ID'])
        else:
            ckpt_folder = save_folder

        # Reload previous checkpoint.
        if os.path.exists(ckpt_folder) and not args.eval:
            latest = latest_file(ckpt_folder, 'weights-')
            if latest is not None:
                log.info('Checkpoint already exists. Loading from {}'.format(
                    latest))
                mem_model.load(latest)  # Not loading optimizer weights here.
                reload_flag = latest
                restore_steps = int(reload_flag.split('-')[-1])

        if not args.eval:
            save_config(config, save_folder)

        # Create TB logger.
        if not args.eval:
            writer = tf.summary.create_file_writer(save_folder)
            logger = ExperimentLogger(writer)

        # Get dataset.
        dataset = get_data_fs(env_config, load_train=True)

        # Get data iterators.
        if env_config.dataset in ["matterport"]:
            data = get_dataiter_sim(
                dataset,
                data_config,
                batch_size=config.optimizer_config.batch_size,
                nchw=mem_model.backbone.config.data_format == 'NCHW',
                seed=args.seed + restore_steps)
        else:
            data = get_dataiter_continual(
                dataset,
                data_config,
                batch_size=config.optimizer_config.batch_size,
                nchw=mem_model.backbone.config.data_format == 'NCHW',
                save_additional_info=True,
                random_box=data_config.random_box,
                seed=args.seed + restore_steps)

    # Load model, training loop.
    if not args.eval:
        if args.pretrain is not None and reload_flag is None:
            model.load(latest_file(args.pretrain, 'weights-'))
            if config.freeze_backbone:
                model.set_trainable(False)  # Freeze the network.
                log.info('Backbone network is now frozen')
        with writer.as_default() if writer is not None else dummy_context_mgr():
            train(mem_model,
                  data['train_fs'],
                  data['trainval_fs'],
                  data['val_fs'],
                  ckpt_folder,
                  final_save_folder=save_folder,
                  maxlen=data_config.maxlen,
                  logger=logger,
                  writer=writer,
                  in_stage=config.in_stage,
                  reload_flag=reload_flag)
    else:
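        # Evaluation: either re-display saved results or reload a checkpoint and re-run evaluation.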
        results_file = os.path.join(save_folder, 'results.pkl')
        logfile = os.path.join(save_folder, 'results.csv')
        if os.path.exists(results_file) and args.reeval:
            # Re-display results.
            results_all = pkl.load(open(results_file, 'rb'))
            for split, name in zip(['trainval_fs', 'val_fs', 'test_fs'],
                                   ['Train', 'Val', 'Test']):
                stats = get_stats(results_all[split], tmax=data_config.maxlen)
                log_results(stats, prefix=name, filename=logfile)
        else:
            # Load the most recent checkpoint.
            if args.usebest:
                latest = latest_file(save_folder, 'best-')
            else:
                latest = latest_file(save_folder, 'weights-')

            if latest is not None:
                mem_model.load(latest)
            else:
                latest = latest_file(args.pretrain, 'weights-')
                if latest is not None:
                    mem_model.load(latest)
                else:
                    raise ValueError('Checkpoint not found')
            data['trainval_fs'].reset()
            data['val_fs'].reset()
            data['test_fs'].reset()

            results_all = {}
            if args.testonly:
                split_list = ['test_fs']
                name_list = ['Test']
                nepisode_list = [config.num_episodes]
            else:
                split_list = ['trainval_fs', 'val_fs', 'test_fs']
                name_list = ['Train', 'Val', 'Test']
                nepisode_list = [600, config.num_episodes, config.num_episodes]

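            # Evaluate each split with the memory model and log aggregate statistics.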
            for split, name, N in zip(split_list, name_list, nepisode_list):
                data[split].reset()
                r1 = evaluate(mem_model, data[split], N, verbose=True)
                stats = get_stats(r1, tmax=data_config.maxlen)
                log_results(stats, prefix=name, filename=logfile)
                results_all[split] = r1
            pkl.dump(results_all, open(results_file, 'wb'))
Example #4
def test_load_one(folder, seed=0):
    # log.info('Command line args {}'.format(args))
    config_file = os.path.join(folder, 'config.prototxt')
    config = get_config(config_file, ExperimentConfig)
    # config.c4_config.data_format = 'NHWC'
    # config.resnet_config.data_format = 'NHWC'
    if 'omniglot' in folder:
        if 'ssl' in folder:
            data_config_file = 'configs/episodes/roaming-omniglot/roaming-omniglot-150-ssl.prototxt'  # NOQA
        else:
            data_config_file = 'configs/episodes/roaming-omniglot/roaming-omniglot-150.prototxt'  # NOQA
        env_config_file = 'configs/environ/roaming-omniglot-docker.prototxt'
    elif 'rooms' in folder:
        if 'ssl' in folder:
            data_config_file = 'configs/episodes/roaming-rooms/roaming-rooms-100.prototxt'  # NOQA
        else:
            data_config_file = 'configs/episodes/roaming-rooms/roaming-rooms-100.prototxt'  # NOQA
        env_config_file = 'configs/environ/roaming-rooms-docker.prototxt'
    data_config = get_config(data_config_file, EpisodeConfig)
    env_config = get_config(env_config_file, EnvironmentConfig)
    log.info('Model: \n{}'.format(config))
    log.info('Data episode: \n{}'.format(data_config))
    log.info('Environment: \n{}'.format(env_config))
    config.num_classes = data_config.nway  # Assign num classes.
    config.num_steps = data_config.maxlen
    config.memory_net_config.max_classes = data_config.nway
    config.memory_net_config.max_stages = data_config.nstage
    config.memory_net_config.max_items = data_config.maxlen
    config.oml_config.num_classes = data_config.nway
    config.fix_unknown = data_config.fix_unknown  # Assign fix unknown ID.
    log.info('Number of classes {}'.format(data_config.nway))

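    # Build the pretraining network and the episodic memory model sharing its backbone.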
    model = build_pretrain_net(config)
    mem_model = build_net(config, backbone=model.backbone)
    reload_flag = None
    restore_steps = 0

    # Reload previous checkpoint.
    latest = latest_file(folder, 'best-')
    if latest is None:
        latest = latest_file(folder, 'weights-')
    assert latest is not None, "Checkpoint not found."

    if latest is not None:
        log.info('Checkpoint already exists. Loading from {}'.format(latest))
        mem_model.load(latest)  # Not loading optimizer weights here.
        reload_flag = latest
        restore_steps = int(reload_flag.split('-')[-1])

    # Get dataset.
    dataset = get_data_fs(env_config, load_train=True)

    # Get data iterators.
    if env_config.dataset in ["roaming-rooms", "matterport"]:
        data = get_dataiter_sim(
            dataset,
            data_config,
            batch_size=config.optimizer_config.batch_size,
            nchw=mem_model.backbone.config.data_format == 'NCHW',
            seed=seed + restore_steps)
    else:
        data = get_dataiter_continual(
            dataset,
            data_config,
            batch_size=config.optimizer_config.batch_size,
            nchw=mem_model.backbone.config.data_format == 'NCHW',
            save_additional_info=True,
            random_box=data_config.random_box,
            seed=seed + restore_steps)

    # Load the most recent checkpoint.
    latest = latest_file(folder, 'best-')
    if latest is None:
        latest = latest_file(folder, 'weights-')

    data['trainval_fs'].reset()
    data['val_fs'].reset()
    data['test_fs'].reset()

    results_all = {}
    split_list = ['trainval_fs', 'val_fs', 'test_fs']
    name_list = ['Train', 'Val', 'Test']
    nepisode_list = [5, 5, 5]
    # nepisode_list = [600, config.num_episodes, config.num_episodes]

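    # Run a handful of episodes per split as a quick smoke test.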
    for split, name, N in zip(split_list, name_list, nepisode_list):
        # print(name)
        data[split].reset()
        r1 = evaluate(mem_model, data[split], N, verbose=False)
        stats = get_stats(r1, tmax=data_config.maxlen)
        print(split, stats['ap'])