Ejemplo n.º 1
0
def main():
    """Run the pretrained semantic-segmentation network over every scene
    in the dataset and save each scene's raw points together with a
    per-point mask of predictions outside classes {0, 1}.

    Output: one tensor per scene at DATA_DIR/<scene>/<scene><OUT_POSTFIX>,
    columns = [coords | feats | targets | mask].
    """
    dataset = DataScheduler(config).dataset

    # Build the segmentation backbone and restore its weights.
    device = config['device']
    sem_cfg = config['semantic_model']
    model = Res16UNet34C(3, config['y_c'], sem_cfg['conv1_kernel_size']).to(device)
    checkpoint = torch.load(sem_cfg['path'])
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    model.load_state_dict(checkpoint)
    model.eval()

    for idx in tqdm(range(len(dataset))):
        coords, feats, targets, scene_name = raw_getitem(dataset, idx)
        coords_voxel = torch.floor(coords / config['voxel_size']).int().numpy()
        vcoords, vfeats, vtargets = dataset.quantize_data(coords, feats, targets)
        # Mapping from each raw point to its voxel's row in the sparse tensor.
        vox_idx = vox2point(vcoords, coords_voxel)

        vcoords, vfeats = sparse_collate([vcoords], [vfeats])
        x = SparseTensor(vfeats, vcoords).to(device)
        y = vtargets

        with torch.no_grad():
            y_hat = model(x)
        # Predicted class per raw point, gathered back from voxel space.
        semantic_pred = y_hat.max(dim=1).indices[vox_idx]
        # Keep points whose predicted class is neither 0 nor 1
        # (presumably background classes — confirm against the label map).
        mask = ((semantic_pred != 0) & (semantic_pred != 1)).float().cpu()

        outfile = os.path.join(DATA_DIR, scene_name, scene_name + OUT_POSTFIX)
        preprocessed_data = torch.cat(
            [coords, feats, targets, mask.unsqueeze(dim=1)], dim=1)
        torch.save(preprocessed_data, outfile)
Ejemplo n.º 2
0
def train_model(config, model: NdpmModel, scheduler: DataScheduler,
                writer: SummaryWriter):
    """Train an NDPM (or plain) model on the scheduled stream.

    Per step: print progress, call ``model.learn``, and — on the configured
    intervals — evaluate the model, log the expert count, and write sample
    collages (per expert plus a prior-weighted mixture) to TensorBoard.
    """
    for step, (x, y, t) in enumerate(scheduler):
        step += 1  # 1-based steps so modulo checks fire *after* N batches
        if isinstance(model, NdpmModel):
            # Show short-term-memory fill level and current expert count
            # (expert 0 is a placeholder, hence the -1).
            print('\r[Step {:4}] STM: {:5}/{} | #Expert: {}'.format(
                step, len(model.ndpm.stm_x), config['stm_capacity'],
                len(model.ndpm.experts) - 1),
                  end='')
        else:
            print('\r[Step {:4}]'.format(step), end='')

        summarize = step % config['summary_step'] == 0
        summarize_experts = summarize and isinstance(model, NdpmModel)
        summarize_samples = summarize and config['summarize_samples']

        # learn the model
        model.learn(x, y, t, step)

        # Evaluate the model; an NDPM model is only evaluatable once it has
        # grown at least one real expert beyond the placeholder.
        evaluatable = (not isinstance(model, NdpmModel)
                       or len(model.ndpm.experts) > 1)
        if evaluatable and step % config['eval_step'] == 0:
            scheduler.eval(model, writer, step, 'model')

        # Log the expert count of the model's DPMoE
        if summarize_experts:
            writer.add_scalar('num_experts', len(model.ndpm.experts) - 1, step)

        # Summarize samples
        if summarize_samples:
            is_ndpm = isinstance(model, NdpmModel)
            comps = [e.g for e in model.ndpm.experts[1:]] \
                if is_ndpm else [model.component]
            # No experts yet — nothing to sample (this is the last section of
            # the loop body, so `continue` skips nothing else).
            if len(comps) == 0:
                continue
            grid_h, grid_w = config['sample_grid']
            total_samples = []
            # Sample from each expert
            for i, expert in enumerate(comps):
                with torch.no_grad():
                    samples = expert.sample(grid_h * grid_w)
                total_samples.append(samples)
                collage = _make_collage(samples, config, grid_h, grid_w)
                writer.add_image('samples/{}'.format(i + 1), collage, step)

            # Extra mixture collage: draw the per-expert sample counts from a
            # multinomial over the DPM prior weights (placeholder excluded).
            if is_ndpm:
                counts = model.ndpm.prior.counts[1:]
                expert_w = counts / counts.sum()
                num_samples = torch.distributions.multinomial.Multinomial(
                    grid_h * grid_w, probs=expert_w).sample().type(torch.int)
                to_collage = []
                for i, samples in enumerate(total_samples):
                    to_collage.append(samples[:num_samples[i]])
                to_collage = torch.cat(to_collage, dim=0)
                collage = _make_collage(to_collage, config, grid_h, grid_w)
                writer.add_image('samples/ndpm', collage, step)
Ejemplo n.º 3
0
def main():
    """Parse CLI args, load and override the YAML config, set up the log
    directory, build the model and data scheduler, then train.

    Side effects: may create ``config['log_dir']`` and save ``config.yaml``
    and ``episode.yaml`` into it.
    """
    args = parser.parse_args()

    # Load config; when resuming without an explicit --config, reuse the
    # config/episode files saved next to the checkpoint.
    config_path = args.config
    episode_path = args.episode
    if args.resume_ckpt and not args.config:
        base_dir = os.path.dirname(os.path.dirname(args.resume_ckpt))
        config_path = os.path.join(base_dir, 'config.yaml')
        episode_path = os.path.join(base_dir, 'episode.yaml')
    # Use context managers so the handles are closed deterministically
    # (the original `yaml.load(open(...))` left them to the GC).
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    with open(episode_path) as f:
        episode = yaml.load(f, Loader=yaml.FullLoader)
    config['data_schedule'] = episode

    # Override options of the form "a.b.c=value|x.y=value".
    for option in args.override.split('|'):
        if not option:
            continue
        # maxsplit=1 keeps values that themselves contain '=' intact.
        address, value = option.split('=', 1)
        keys = address.split('.')
        here = config
        for key in keys[:-1]:
            if key not in here:
                raise ValueError('{} is not defined in config file. '
                                 'Failed to override.'.format(address))
            here = here[key]
        if keys[-1] not in here:
            raise ValueError('{} is not defined in config file. '
                             'Failed to override.'.format(address))
        here[keys[-1]] = yaml.load(value, Loader=yaml.FullLoader)

    # Set log directory; warn before clobbering an existing run.
    config['log_dir'] = args.log_dir
    if not args.resume_ckpt and os.path.exists(args.log_dir):
        print('WARNING: %s already exists' % args.log_dir)
        input('Press enter to continue')

    # Resuming without an explicit --log_dir: log next to the checkpoint.
    if args.resume_ckpt and not args.log_dir:
        config['log_dir'] = os.path.dirname(os.path.dirname(args.resume_ckpt))

    # Save config (skipped when resuming with the saved config unchanged).
    os.makedirs(config['log_dir'], mode=0o755, exist_ok=True)
    if not args.resume_ckpt or args.config:
        config_save_path = os.path.join(config['log_dir'], 'config.yaml')
        episode_save_path = os.path.join(config['log_dir'], 'episode.yaml')
        with open(config_save_path, 'w') as f:
            yaml.dump(config, f)
        with open(episode_save_path, 'w') as f:
            yaml.dump(episode, f)
        print('Config & episode saved to {}'.format(config['log_dir']))

    # Build components and train.
    data_scheduler = DataScheduler(config)
    writer = SummaryWriter(config['log_dir'])
    model = MODEL[config['model_name']](config, writer)
    if args.resume_ckpt:
        model.load_state_dict(torch.load(args.resume_ckpt))
    model.to(config['device'])
    train_model(config, model, data_scheduler, writer)
Ejemplo n.º 4
0
def train_model(config, model: Model, scheduler: DataScheduler,
                writer: SummaryWriter):
    """Train ``model`` on batches from ``scheduler``, periodically
    evaluating, visualizing, and checkpointing.

    Args:
        config: experiment config dict (reads 'log_dir', 'device',
            'skip_gpu_overflow', 'ckpt_step').
        model: object exposing ``learn``, ``state_dict`` and ``lr_scheduler``.
        scheduler: data scheduler yielding ``(x, y, epoch)`` batches.
        writer: TensorBoard summary writer.
    """
    saved_model_path = os.path.join(config['log_dir'], 'ckpts')
    os.makedirs(saved_model_path, exist_ok=True)

    skip_batch = 0  # batches dropped because of GPU memory overflow
    for step, (x, y, epoch) in enumerate(scheduler):

        x, y = x.to(config['device']), y.to(config['device'])

        # Since the number of points varies across the dataset, an
        # occasional batch can exceed GPU memory; optionally skip it.
        if config['skip_gpu_overflow']:
            try:
                train_loss = model.learn(x, y, step)
            except RuntimeError as err:
                # Only swallow genuine OOM errors.  The original caught
                # *every* RuntimeError, silently hiding unrelated bugs
                # (shape mismatches, device errors, ...).
                if 'out of memory' not in str(err):
                    raise
                skip_batch += 1
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()  # release cached blocks after OOM
                continue
        else:
            train_loss = model.learn(x, y, step)

        # Progress line ('\r' keeps it on a single console line).
        print('\r[Epoch {:4}, Step {:7}, Overflow: {:7}, Loss {:5}]'.format(
            epoch, step, skip_batch, '%.3f' % train_loss),
              end='')

        # Periodic evaluation / visualization, as decided by the scheduler.
        if scheduler.check_eval_step(step):
            scheduler.eval(model, writer, step)

        if scheduler.check_vis_step(step):
            print("\nVisualizing...")
            scheduler.visualize(model, writer, step)
            writer.add_scalar('skip_batch', skip_batch, step)

        # Periodic checkpoint, named by 1-based step.
        if (step + 1) % config['ckpt_step'] == 0:
            torch.save(
                model.state_dict(),
                os.path.join(saved_model_path,
                             'ckpt-step-{}'.format(str(step + 1).zfill(3))))

        # Advance the learning-rate schedule once per batch.
        model.lr_scheduler.step()
Ejemplo n.º 5
0
Archivo: train.py Proyecto: yyht/PRS
def train_model(config, model: MLabReservoir, scheduler: DataScheduler,
                writer: SummaryWriter):
    """Train a reservoir model over a sequence of tasks, evaluating and
    checkpointing at every task boundary.

    The scheduler yields ``(x, y, t)`` with ``t`` the current task id; a
    change in ``t`` triggers evaluation of the finished task, a checkpoint,
    and a progress banner.  Returns early once all tasks are done.
    """
    saved_model_path = os.path.join(config['log_dir'], 'ckpts')

    os.makedirs(saved_model_path, exist_ok=True)

    # Task id of the first scheduled subset; used to detect task changes.
    prev_t = config['data_schedule'][0]['subsets'][0][1]
    done_t_num = 0  # number of tasks finished so far

    results_dict = dict()
    for step, (x, y, t) in enumerate(scheduler):

        summarize = step % config['summary_step'] == 0
        # if we want to evaluate based on steps.
        evaluate = (step % config['eval_step'] == config['eval_step'] - 1)

        # find current task t's id in data_schedule to obtain the data name.
        # NOTE(review): if no subset matches t, cur_subset stays unbound and
        # the model.learn() calls below raise NameError; if several match,
        # the last one wins — confirm every yielded t is in the schedule.
        for data_dict in config['data_schedule']:
            for subset in data_dict['subsets']:
                if subset[1] == t:
                    cur_subset = subset[0]

        # Evaluate the model when task changes
        if t != prev_t:
            done_t_num += 1
            results_dict = scheduler.eval(model,
                                          writer,
                                          step + 1,
                                          prev_t,
                                          eval_title='eval',
                                          results_dict=results_dict)
            # Save the model
            torch.save(
                model.state_dict(),
                os.path.join(saved_model_path,
                             'ckpt-{}'.format(str(step + 1).zfill(6))))

            print(
                colorful.bold_green('\nProgressing to Task %d' %
                                    t).styled_string)

        # Banner for the very first task as well.
        if step == 0:
            print(
                colorful.bold_green('\nProgressing to Task %d' %
                                    t).styled_string)

        # All tasks finished — flush TensorBoard buffers and stop.
        if done_t_num >= len(scheduler.schedule):
            writer.flush()
            return

        # learn the model: run config['batch_iter'] updates per batch, with a
        # global step of step * batch_iter + i.  The two branches differ only
        # in which data_obj is handed to the model.
        for i in range(config['batch_iter']):
            if 'slab' in config['model_name']:
                model.learn(x,
                            y,
                            t,
                            step * config['batch_iter'] + i,
                            scheduler.datasets[cur_subset].category_map,
                            scheduler.datasets[cur_subset].split_cats_dict,
                            data_obj=scheduler.datasets[cur_subset])
            else:
                model.learn(x,
                            y,
                            t,
                            step * config['batch_iter'] + i,
                            scheduler.datasets[cur_subset].category_map,
                            scheduler.datasets[cur_subset].split_cats_dict,
                            data_obj=scheduler.datasets[cur_subset].subsets[t])

        prev_t = t
Ejemplo n.º 6
0
def train_model(config: Dict, model: NdpmModel,
                scheduler: DataScheduler,
                writer: SummaryWriter):
    """Train an NDPM (or plain) model with periodic evaluation, pickle
    checkpointing, and TensorBoard sample summaries.

    Checkpoints pickle the *whole model object*; the (unpicklable) writer
    references are detached before dumping and re-attached afterwards.
    """
    saved_model_path = os.path.join(config["log_dir"], "ckpts")
    os.makedirs(saved_model_path, exist_ok=True)

    is_ndmp = isinstance(model, NdpmModel)

    for step, (x, y, t) in enumerate(scheduler):
        step += 1  # 1-based steps so the modulo checks fire after N batches
        if is_ndmp:
            # Show short-term-memory fill level and expert count
            # (expert 0 is a placeholder, hence the -1).
            stm_item_count = len(model.ndpm.stm_x)
            stm_capacity = config["stm_capacity"]
            print(
                f"\r[Step {step:4}]",
                f"STM: {stm_item_count:5}/{stm_capacity}",
                f"| #Expert: {len(model.ndpm.experts) - 1}",
                end="",
            )
        else:
            print("\r[Step {:4}]".format(step), end="")

        summarize = step % config["summary_step"] == 0
        summarize_experts = summarize and isinstance(model, NdpmModel)
        summarize_samples = summarize and config["summarize_samples"]

        # learn the model
        model.learn(x, y, t, step)

        # Evaluate the model (NDPM only once it has a real expert).
        evaluatable = (
            not is_ndmp or len(model.ndpm.experts) > 1
        )
        if evaluatable and step % config["eval_step"] == 0:
            scheduler.eval(model, writer, step, "model")

        if step % config["ckpt_step"] == 0:
            print("\nSaving checkpoint... ", end="")
            ckpt_path = os.path.join(saved_model_path,
                                     "ckpt-{}.pt".format(str(step).zfill(6)))
            # SummaryWriter is not picklable: detach before dumping and
            # restore right after — order matters here.
            del model.writer
            if is_ndmp:
                del model.ndpm.writer
            with open(ckpt_path, "wb") as f:
                pickle.dump(model, f)
            model.writer = writer
            if is_ndmp:
                model.ndpm.writer = writer
            print("Saved to {}".format(ckpt_path))

        # Log the expert count of the model's DPMoE
        if summarize_experts:
            writer.add_scalar("num_experts", len(model.ndpm.experts) - 1, step)

        # Summarize samples
        if summarize_samples:
            if is_ndmp:
                comps = [e.g for e in model.ndpm.experts[1:]]
            else:
                comps = [model.component]

            # No experts yet — nothing to sample (last section of the loop
            # body, so `continue` skips nothing else).
            if len(comps) == 0:
                continue
            grid_h, grid_w = config["sample_grid"]
            total_samples = []
            # Sample from each expert
            for i, expert in enumerate(comps):
                with torch.no_grad():
                    samples = expert.sample(grid_h * grid_w)
                total_samples.append(samples)
                collage = _make_collage(samples, config, grid_h, grid_w)
                writer.add_image(f"samples/{i + 1}", collage, step)

            # Mixture collage: per-expert sample counts drawn from a
            # multinomial over the DPM prior weights (placeholder excluded).
            if is_ndmp:
                counts = model.ndpm.prior.counts[1:]
                expert_w = counts / counts.sum()
                num_samples = torch.distributions.multinomial.Multinomial(
                    grid_h * grid_w, probs=expert_w).sample().type(torch.int)
                to_collage = []
                for i, samples in enumerate(total_samples):
                    to_collage.append(samples[:num_samples[i]])
                to_collage = torch.cat(to_collage, dim=0)
                collage = _make_collage(to_collage, config, grid_h, grid_w)
                writer.add_image("samples/ndpm", collage, step)
Ejemplo n.º 7
0
Archivo: main.py Proyecto: yyht/PRS
def main():
    """Parse CLI args, load and override the YAML config, prepare the log
    directory, build the model and data scheduler, then train.

    Side effects: may create ``config['log_dir']`` and save ``config.yaml``
    and ``episode.yaml`` into it.
    """
    args = parser.parse_args()
    logger = setup_logger()

    ## Use below for slurm setting.
    # slurm_job_id = os.getenv('SLURM_JOB_ID', 'nojobid')
    # slurm_proc_id = os.getenv('SLURM_PROC_ID', None)

    # unique_identifier = str(slurm_job_id)
    # if slurm_proc_id is not None:
    #     unique_identifier = unique_identifier + "_" + str(slurm_proc_id)
    unique_identifier = ''

    # Load config; when resuming without an explicit --config, reuse the
    # config/episode files saved next to the checkpoint.
    config_path = args.config
    episode_path = args.episode

    if args.resume_ckpt and not args.config:
        base_dir = os.path.dirname(os.path.dirname(args.resume_ckpt))
        config_path = os.path.join(base_dir, 'config.yaml')
        episode_path = os.path.join(base_dir, 'episode.yaml')
    # Use context managers so the handles are closed deterministically
    # (the original `yaml.load(open(...))` left them to the GC).
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    with open(episode_path) as f:
        episode = yaml.load(f, Loader=yaml.FullLoader)
    config['data_schedule'] = episode

    # Override options of the form "a.b.c=value|x.y=value".
    for option in args.override.split('|'):
        if not option:
            continue
        # maxsplit=1 keeps values that themselves contain '=' intact.
        address, value = option.split('=', 1)
        keys = address.split('.')
        here = config
        for key in keys[:-1]:
            if key not in here:
                raise ValueError('{} is not defined in config file. '
                                 'Failed to override.'.format(address))
            here = here[key]
        if keys[-1] not in here:
            raise ValueError('{} is not defined in config file. '
                             'Failed to override.'.format(address))
        here[keys[-1]] = yaml.load(value, Loader=yaml.FullLoader)

    # Set log directory; warn before clobbering an existing run.
    config['log_dir'] = os.path.join(args.log_dir, unique_identifier)
    if not args.resume_ckpt and os.path.exists(config['log_dir']):
        logger.warning('%s already exists' % config['log_dir'])
        input('Press enter to continue')

    # print the configuration
    print(colorful.bold_white("configuration:").styled_string)
    pprint(config)
    print(colorful.bold_white("configuration end").styled_string)

    # Resuming without an explicit --log_dir: log next to the checkpoint.
    if args.resume_ckpt and not args.log_dir:
        config['log_dir'] = os.path.dirname(
            os.path.dirname(args.resume_ckpt)
        )

    # Save config (skipped when resuming with the saved config unchanged).
    os.makedirs(config['log_dir'], mode=0o755, exist_ok=True)
    if not args.resume_ckpt or args.config:
        config_save_path = os.path.join(config['log_dir'], 'config.yaml')
        episode_save_path = os.path.join(config['log_dir'], 'episode.yaml')
        with open(config_save_path, 'w') as f:
            yaml.dump(config, f)
        with open(episode_save_path, 'w') as f:
            yaml.dump(episode, f)
        print(colorful.bold_yellow('config & episode saved to {}'.format(config['log_dir'])).styled_string)

    # Build components and train.
    data_scheduler = DataScheduler(config)

    writer = SummaryWriter(config['log_dir'])
    model = MODEL[config['model_name']](config, writer)

    if args.resume_ckpt:
        model.load_state_dict(torch.load(args.resume_ckpt))
    model.to(config['device'])
    train_model(config, model, data_scheduler, writer)

    print(colorful.bold_white("\nThank you and Good Job Computer").styled_string)