Example #1
    def step(self, feed_dict, reduce_func=default_reduce_func):
        assert self._model.training, 'Step an evaluation-mode model.'

        self.trigger_event('step:before', self)

        feed_dict = as_variable(feed_dict)

        begin = time.time()

        self.trigger_event('forward:before', self, feed_dict)
        loss, monitors, output_dict = self._model(feed_dict)
        self.trigger_event('forward:after', self, feed_dict, loss, monitors,
                           output_dict)

        loss = reduce_func('loss', loss)
        monitors = {k: reduce_func(k, v) for k, v in monitors.items()}

        loss_f = as_float(loss)
        monitors_f = as_float(monitors)

        self._optimizer.zero_grad()
        self.trigger_event('backward:before', self, loss)
        loss.backward()
        self.trigger_event('backward:after', self, loss)
        self._optimizer.step()

        end = time.time()

        self.trigger_event('step:after', self)

        return loss_f, monitors_f, output_dict, {'time/gpu': end - begin}
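For reference, here is a minimal sketch of the two helpers these examples lean on. This is a hypothetical re-implementation for illustration only; the real default_reduce_func and as_float live in jacinle/jactorch and also handle numpy arrays and nested containers.

import torch

def default_reduce_func(k, v):
    # Reduce a (possibly multi-element) monitor tensor to a scalar tensor.
    if torch.is_tensor(v):
        return v.mean()
    return v

def as_float(obj):
    # Recursively convert tensors and dicts of tensors to plain Python floats.
    if torch.is_tensor(obj):
        return obj.item()
    if isinstance(obj, dict):
        return {k: as_float(v) for k, v in obj.items()}
    return float(obj)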
Example #2
    def step(self,
             feed_dict,
             reduce_func=default_reduce_func,
             cast_tensor=False,
             measure_time=False):
        if hasattr(self.model, 'train_step'):
            return self.model.train_step(self.optimizer, feed_dict)

        assert self._model.training, 'Step an evaluation-mode model.'
        extra = dict()

        self.trigger_event('step:before', self)

        if cast_tensor:
            feed_dict = as_tensor(feed_dict)

        if measure_time:
            end_time = cuda_time()

        self.trigger_event('forward:before', self, feed_dict)
        loss, monitors, output_dict = self._model(feed_dict)
        self.trigger_event('forward:after', self, feed_dict, loss, monitors,
                           output_dict)

        if measure_time:
            extra['time/forward'] = cuda_time() - end_time
            end_time = cuda_time(False)

        loss = reduce_func('loss', loss)
        monitors = {k: reduce_func(k, v) for k, v in monitors.items()}

        loss_f = as_float(loss)
        monitors_f = as_float(monitors)

        if measure_time:
            extra['time/loss'] = cuda_time() - end_time
            end_time = cuda_time(False)

        self._optimizer.zero_grad()
        self.trigger_event('backward:before', self, feed_dict, loss, monitors,
                           output_dict)
        if loss.requires_grad:
            loss.backward()

        if measure_time:
            extra['time/backward'] = cuda_time() - end_time
            end_time = cuda_time(False)

        self.trigger_event('backward:after', self, feed_dict, loss, monitors,
                           output_dict)
        if loss.requires_grad:
            self._optimizer.step()

        if measure_time:
            extra['time/optimize'] = cuda_time() - end_time
            end_time = cuda_time(False)

        self.trigger_event('step:after', self)

        return loss_f, monitors_f, output_dict, extra
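This variant assumes a cuda_time helper roughly like the sketch below (hypothetical; the flag controls whether to synchronize CUDA before reading the clock, so that pending GPU kernels are counted in the measurement).

import time
import torch

def cuda_time(sync=True):
    # Synchronize so asynchronous GPU work is included in the timing.
    if sync and torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()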
Example #3
    def test_jac_dataloader(self):
        ds = _FakeDataset()
        dl = JacDataLoader(ds,
                           num_workers=2,
                           worker_init_fn=_my_init_func,
                           worker_init_args=[('hello',), ('world',)])
        res = list(dl)
        self.assertNotEqual(as_float(res[0]), as_float(res[1]))
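The fixtures are not shown; a minimal hypothetical version is sketched below. Each worker seeds its RNG from its own init argument, so the two workers yield different values and the assertNotEqual holds.

import random
from torch.utils.data import Dataset

class _FakeDataset(Dataset):
    def __len__(self):
        return 2

    def __getitem__(self, idx):
        # The sample depends on the worker's RNG state set by the init function.
        return random.random()

def _my_init_func(worker_id, arg):
    # Derive a distinct, deterministic seed from the per-worker argument.
    random.seed(sum(map(ord, arg)))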
Example #4
def step_epoch(
    model,
    loader,
    criterion,
    optimizer,
    epoch,
):

    non_blocking = True
    total_loss = 0.
    total_acc_qa = 0.

    pbar = tqdm(loader)
    for i, feed_dict in enumerate(pbar):
        for key in ('image', 'objects', 'objects_length', 'questions',
                    'answer'):
            feed_dict[key] = feed_dict[key].to(device,
                                               non_blocking=non_blocking)

        programs, buffers, answers = model(feed_dict)
        loss = criterion(feed_dict, answers)

        monitors = {}
        outputs = {
            'buffers': buffers,
            'answer': answers,
        }
        update_from_loss_module(monitors, outputs, loss)
        canonize_monitors(monitors)

        loss = monitors['loss/qa']

        loss = default_reduce_func('loss', loss)
        monitors = {k: default_reduce_func(k, v) for k, v in monitors.items()}

        loss_f = as_float(loss)
        monitors_f = as_float(monitors)

        total_loss += loss_f
        total_acc_qa += monitors_f['acc/qa']

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix(
            loss=f'{loss_f:.4f} ({total_loss/(i + 1):.4f})',
            acc_qa=f'{monitors_f["acc/qa"]:.4f} ({total_acc_qa/(i + 1):.4f})',
        )
        pbar.update()

    return total_loss / (i + 1), total_acc_qa / (i + 1)
Example #5
def regression_accuracy(pred, label, name=''):
    if name != '':
        name = '/' + name
    prefix = 'accuracy' + name
    pred = pred.view(-1)  # Flatten to 1D
    label = label.view(-1)
    diff = pred - label
    return {
        prefix + '/l1': as_float(diff.abs().mean()),
        prefix + '/l2': as_float(0.5 * diff.pow(2).mean())
    }
Example #6
def binary_classification_accuracy(pred, label, name='', saturation=True):
    if name != '':
        name = '/' + name
    prefix = 'accuracy' + name
    pred = pred.view(-1)  # Flatten to 1D
    label = label.view(-1)
    acc = label.float().eq((pred > 0.5).float())
    if saturation:
        sat = 1 - (pred - (pred > 0.5).float()).abs()
        return {
            prefix: as_float(acc.float().mean()),
            prefix + '/saturation/mean': as_float(sat.mean()),
            prefix + '/saturation/min': as_float(sat.min())
        }
    return {prefix: as_float(acc.float().mean())}
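Usage sketch (hypothetical tensors; predictions close to 0 or 1 give a saturation near 1):

import torch

pred = torch.tensor([0.05, 0.98, 0.6])
label = torch.tensor([0., 1., 1.])
print(binary_classification_accuracy(pred, label, name='gate'))
# accuracy/gate == 1.0, saturation mean ~= 0.84, saturation min == 0.6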
Example #7
def validate_epoch(model, val_dataloader, meters):
    end = time.time()
    with tqdm_pbar(total=len(val_dataloader)) as pbar:
        for feed_dict in val_dataloader:
            if args.use_gpu:
                if not args.gpu_parallel:
                    feed_dict = async_copy_to(feed_dict, 0)

            data_time = time.time() - end
            end = time.time()

            output_dict = model(feed_dict)

            # TODO(Jiayuan Mao @ 04/26): compute the monitoring values.
            monitors = as_float(output_dict['monitors'])
            step_time = time.time() - end
            end = time.time()

            # TODO(Jiayuan Mao @ 04/23): normalize the loss/other metrics by adding n=xxx if applicable.
            meters.update(monitors)
            meters.update({'time/data': data_time, 'time/step': step_time})

            if args.use_tb:
                meters.flush()

            pbar.set_description(
                meters.format_simple('Test', 'val', compressed=True))
            pbar.update()

            end = time.time()
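Several of these validation loops move a nested feed_dict onto the GPU with async_copy_to. A hypothetical sketch of its behavior (the real helper lives in jacinle and additionally handles CUDA streams):

import torch

def async_copy_to(obj, dev):
    # Recursively move tensors in a nested structure to the device, non-blocking.
    if torch.is_tensor(obj):
        return obj.to(dev, non_blocking=True)
    if isinstance(obj, dict):
        return {k: async_copy_to(v, dev) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(async_copy_to(v, dev) for v in obj)
    return obj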
Example #8
def validate_epoch(epoch, trainer, val_dataloader, meters):
    end = time.time()
    with tqdm_pbar(total=len(val_dataloader)) as pbar:
        for feed_dict in val_dataloader:
            if args.use_gpu:
                if not args.gpu_parallel:
                    feed_dict = async_copy_to(feed_dict, 0)

            data_time = time.time() - end; end = time.time()

            output_dict, extra_info = trainer.evaluate(feed_dict)

            # TODO(Jiayuan Mao @ 04/26): compute the monitoring values.
            monitors = as_float(output_dict['monitors'])
            step_time = time.time() - end; end = time.time()

            # TODO(Jiayuan Mao @ 04/23): normalize the loss/other metrics by adding n=xxx if applicable.
            meters.update(monitors)
            meters.update({'time/data': data_time, 'time/step': step_time})

            if args.use_tb:
                meters.flush()

            pbar.set_description(meters.format_simple(
                'Epoch {} (validation)'.format(epoch),
                {k: v for k, v in meters.val.items() if k.startswith('validation') and k.count('/') <= 2},
                compressed=True
            ), refresh=False)
            pbar.update()

            end = time.time()
Example #9
def validate_epoch(epoch, trainer, val_dataloader, meters, meter_prefix='validation'):
    end = time.time()
    with tqdm_pbar(total=len(val_dataloader)) as pbar:
        for feed_dict in val_dataloader:
            if args.use_gpu:
                if not args.gpu_parallel:
                    feed_dict = async_copy_to(feed_dict, 0)

            data_time = time.time() - end; end = time.time()

            output_dict, extra_info = trainer.evaluate(feed_dict, cast_tensor=False)
            monitors = {meter_prefix + '/' + k: v for k, v in as_float(output_dict['monitors']).items()}
            step_time = time.time() - end; end = time.time()

            n = feed_dict['image'].size(0)
            meters.update(monitors, n=n)
            meters.update({'time/data': data_time, 'time/step': step_time})

            if args.use_tb:
                meters.flush()

            pbar.set_description(meters.format_simple(
                'Epoch {} (validation)'.format(epoch),
                {k: v for k, v in meters.val.items() if k.startswith('validation') and k.count('/') <= 2},
                compressed=True
            ))
            pbar.update()

            end = time.time()
Example #10
def classification_accuracy(pred, label, name=''):
    if name != '':
        name = '/' + name
    prefix = 'accuracy' + name
    pred = pred.view(-1)  # Flatten to 1D
    label = label.view(-1)
    acc = label.float().eq(pred.float())
    return {prefix: as_float(acc.float().mean())}
Example #11
    def validate_step(self, feed_dict, metric, meters=None):
        feed_dict_np = as_numpy(feed_dict)
        # mark_volatile is the legacy (pre-PyTorch-0.4) way to disable autograd;
        # the next example uses torch.no_grad() instead.
        feed_dict = mark_volatile(as_variable(feed_dict))
        output_dict = self._model(feed_dict)
        output_dict_np = as_numpy(output_dict)
        result = as_float(metric(feed_dict_np, output_dict_np))
        if meters is not None:
            meters.update(result)
        return result
Example #12
    def validate_step(self, feed_dict, metric, meters=None):
        feed_dict_np = as_numpy(feed_dict)
        feed_dict = as_tensor(feed_dict)
        with torch.no_grad():
            output_dict = self._model(feed_dict)
        output_dict_np = as_numpy(output_dict)
        result = as_float(metric(feed_dict_np, output_dict_np))
        if meters is not None:
            meters.update(result)
        return result
Example #13
    def update(self, feed_dict, loss, monitors, output_dict, grad_clip=0.,
               reduce_func=default_reduce_func, measure_time=False, extra=None):
        assert self.__prepared, 'Two consecutive calls to TrainerEnv.update()'
        self.__prepared = False

        if extra is None:
            extra = dict()

        if measure_time:
            end_time = cuda_time(False)  # start the timer used by the measurements below

        loss = reduce_func('loss', loss)
        monitors = {k: reduce_func(k, v) for k, v in monitors.items()}

        loss_f = as_float(loss)
        monitors_f = as_float(monitors)

        if measure_time:
            extra['time/loss'] = cuda_time() - end_time
            end_time = cuda_time(False)

        self.trigger_event('backward:before', self, feed_dict, loss, monitors, output_dict)
        if loss.requires_grad:
            loss.backward()
            if grad_clip > 0:
                from torch.nn.utils.clip_grad import clip_grad_norm_
                clip_grad_norm_(self.model.parameters(), grad_clip)

        if measure_time:
            extra['time/backward'] = cuda_time() - end_time
            end_time = cuda_time(False)

        self.trigger_event('backward:after', self, feed_dict, loss, monitors, output_dict)
        if loss.requires_grad:
            self._optimizer.step()

        if measure_time:
            extra['time/optimize'] = cuda_time() - end_time
            end_time = cuda_time(False)

        self.trigger_event('step:after', self)
        return loss_f, monitors_f, output_dict, extra
Example #14
    def train_step(self, feed_dict, meters=None):
        assert self._model.training
        feed_dict = as_tensor(feed_dict)

        self._optimizer.zero_grad()
        loss, monitors, output_dict = self._model(feed_dict)
        loss.backward()
        self._optimizer.step()

        loss, monitors = map(as_float, [loss, monitors])
        if meters is not None:
            meters.update(loss=loss)
            meters.update(monitors)

        return loss  # already a float after the as_float conversion above
Example #15
def binary_accuracy(label, raw_pred, eps=1e-20, return_float=True):
    """get accuracy for binary classification problem."""
    pred = as_tensor(raw_pred).squeeze(-1)
    pred = (pred > 0.5).float()
    label = as_tensor(label).float()
    # $acc is the micro accuracy: number of correct predictions / total.
    acc = label.eq(pred).float()

    # $balanced_accuracy is the macro accuracy, balanced across the two classes.
    nr_total = torch.ones(label.size(), dtype=label.dtype,
                          device=label.device).sum(dim=-1)
    nr_pos = label.sum(dim=-1)
    nr_neg = nr_total - nr_pos
    pos_cnt = (acc * label).sum(dim=-1)
    neg_cnt = acc.sum(dim=-1) - pos_cnt
    balanced_acc = ((pos_cnt + eps) / (nr_pos + eps) + (neg_cnt + eps) /
                    (nr_neg + eps)) / 2.0

    # $sat measures the saturation rate of the prediction,
    # i.e. how close the predictions are to 0 or 1.
    sat = 1 - (raw_pred - pred).abs()
    if return_float:
        acc = as_float(acc.mean())
        balanced_acc = as_float(balanced_acc.mean())
        sat_mean = as_float(sat.mean())
        sat_min = as_float(sat.min())
    else:
        sat_mean = sat.mean(dim=-1)
        sat_min = sat.min(dim=-1)[0]

    return {
        'accuracy': acc,
        'balanced_accuracy': balanced_acc,
        'saturation/mean': sat_mean,
        'saturation/min': sat_min,
    }
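Usage sketch (hypothetical tensors; one batch of four binary predictions):

import torch

label = torch.tensor([[1., 0., 0., 1.]])
raw_pred = torch.tensor([[0.9, 0.2, 0.4, 0.7]])
print(binary_accuracy(label, raw_pred))
# accuracy == 1.0, balanced_accuracy == 1.0,
# saturation/mean == 0.75, saturation/min == 0.6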
Example #16
def validate_epoch(epoch, trainer, val_dataloader, meters, meter_prefix='validation'):
    if args.testing_flag:
        json_output_list = []
    
    end = time.time()
    with tqdm_pbar(total=len(val_dataloader)*args.batch_size) as pbar:
        for feed_dict in val_dataloader:
            if args.use_gpu:
                if not args.gpu_parallel:
                    feed_dict = async_copy_to(feed_dict, 0)
            #pdb.set_trace()
            data_time = time.time() - end; end = time.time()
            output_dict_list, extra_info = trainer.evaluate(feed_dict, cast_tensor=False)
            if args.testing_flag:
                prepare_data_for_testing(output_dict_list, feed_dict, json_output_list)

            step_time = time.time() - end; end = time.time()
            for idx, mon_dict in enumerate(output_dict_list['monitors']):
                monitors = {meter_prefix + '/' + k: v for k, v in as_float(mon_dict).items()}
                # remove padding values
                for tmp_key, tmp_value in monitors.items():
                    if isinstance(tmp_value, list):
                        for sub_idx, sub_value in enumerate(tmp_value):
                            if sub_value[0] == -1:
                                continue
                            meters.update({tmp_key: sub_value[0]}, n=sub_value[1])
                    elif tmp_value == -1:
                        continue
                    else:
                        meters.update({tmp_key: tmp_value}, n=1)
                
                meters.update({'time/data': data_time, 'time/step': step_time})
                if args.use_tb:
                    meters.flush()

                pbar.set_description(meters.format_simple(
                    'Epoch {} (validation)'.format(epoch),
                    {k: v for k, v in meters.val.items() if k.startswith('validation') and k.count('/') <= 2},
                    compressed=True
                ))
                pbar.update()

            end = time.time()
    if args.testing_flag == 1:
        jsondump(args.test_result_path, json_output_list)
Example #17
def validate_epoch(epoch,
                   trainer,
                   val_dataloader,
                   meters,
                   meter_prefix="validation"):
    end = time.time()
    with tqdm_pbar(total=len(val_dataloader)) as pbar:
        for feed_dict in val_dataloader:
            if args.use_gpu:
                if not args.gpu_parallel:
                    feed_dict = async_copy_to(feed_dict, 0)

            data_time = time.time() - end
            end = time.time()

            output_dict, extra_info = trainer.evaluate(feed_dict,
                                                       cast_tensor=False)
            monitors = {
                meter_prefix + "/" + k: v
                for k, v in as_float(output_dict["monitors"]).items()
            }
            step_time = time.time() - end
            end = time.time()

            n = feed_dict["image"].size(0)
            meters.update(monitors, n=n)
            meters.update({"time/data": data_time, "time/step": step_time})

            if args.use_tb:
                meters.flush()

            pbar.set_description(
                meters.format_simple(
                    "Epoch {} (validation)".format(epoch),
                    {
                        k: v
                        for k, v in meters.val.items()
                        if (k.startswith(meter_prefix)) and k.count("/") <= 2
                    },
                    compressed=True,
                ))
            pbar.update()

            end = time.time()
Example #18
def rms(p):
    """Root mean square function."""
    return as_float((as_tensor(p)**2).mean()**0.5)
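For instance (hypothetical call):

print(rms([3.0, 4.0]))  # sqrt((9 + 16) / 2) ~= 3.5355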
Example #19
    def test_torch_dataloader(self):
        ds = _FakeDataset()
        # With no per-worker init function, both workers start from the same
        # inherited RNG state, so the two samples are expected to coincide.
        dl = DataLoader(ds, num_workers=2)
        res = list(dl)
        self.assertEqual(as_float(res[0]), as_float(res[1]))
Example #20
def validate_epoch(epoch,
                   model,
                   val_dataloader,
                   meters,
                   meter_prefix='validation'):
    end = time.time()

    visualized = 0
    vis = HTMLTableVisualizer(args.vis_dir, 'NSCL Execution Visualization')
    vis.begin_html()

    try:
        with tqdm_pbar(total=len(val_dataloader)) as pbar:
            for feed_dict in val_dataloader:
                if args.use_gpu:
                    if not args.gpu_parallel:
                        feed_dict = async_copy_to(feed_dict, 0)

                data_time = time.time() - end
                end = time.time()

                output_dict = model(feed_dict)
                monitors = {
                    meter_prefix + '/' + k: v
                    for k, v in as_float(output_dict['monitors']).items()
                }
                step_time = time.time() - end
                end = time.time()

                n = feed_dict['image'].size(0)
                meters.update(monitors, n=n)
                meters.update({'time/data': data_time, 'time/step': step_time})

                feed_dict = GView(as_detached(as_cpu(feed_dict)))
                output_dict = GView(as_detached(as_cpu(output_dict)))

                for i in range(n):
                    with vis.table(
                            'Visualize #{} Metainfo'.format(visualized), [
                                HTMLTableColumnDesc('id', 'QID', 'text',
                                                    {'width': '50px'}),
                                HTMLTableColumnDesc('image', 'Image', 'figure',
                                                    {'width': '400px'}),
                                HTMLTableColumnDesc('qa', 'QA', 'text',
                                                    {'width': '200px'}),
                                HTMLTableColumnDesc('p', 'Program', 'code',
                                                    {'width': '200px'})
                            ]):
                        image_filename = osp.join(args.data_image_root,
                                                  feed_dict.image_filename[i])
                        image = Image.open(image_filename)
                        fig, ax = vis_bboxes(image,
                                             feed_dict.objects_raw[i],
                                             'object',
                                             add_text=False)
                        _ = ax.set_title('object bounding box annotations')
                        QA_string = """
                            <p><b>Q</b>: {}</p>
                            <p><b>A</b>: {}</p>
                        """.format(feed_dict.question_raw[i],
                                   feed_dict.answer[i])
                        P_string = '\n'.join(
                            [repr(x) for x in feed_dict.program_seq[i]])

                        vis.row(id=i, image=fig, qa=QA_string, p=P_string)
                        plt.close()

                    with vis.table(
                            'Visualize #{} Metainfo'.format(visualized), [
                                HTMLTableColumnDesc('id', 'QID', 'text',
                                                    {'width': '50px'}),
                                HTMLTableColumnDesc('image', 'Image', 'figure',
                                                    {'width': '400px'}),
                                HTMLTableColumnDesc('mask', 'Mask', 'figure',
                                                    {'width': '700px'})
                            ]):
                        image_filename = osp.join(args.data_image_root,
                                                  feed_dict.image_filename[i])
                        image = Image.open(image_filename)
                        fig, ax = vis_bboxes(image,
                                             feed_dict.objects_raw[i],
                                             'object',
                                             add_text=False)
                        _ = ax.set_title('object bounding box annotations')
                        if not args.show_mask:
                            montage = fig
                        else:
                            num_slots = output_dict['monet/m'].shape[1]
                            monet_fig = [
                                [
                                    tensor2im(output_dict['monet/m'][i, k])
                                    for k in range(num_slots)
                                ],
                                [
                                    tensor2im(output_dict['monet/x'][i, k])
                                    for k in range(num_slots)
                                ],
                                [
                                    tensor2im(output_dict['monet/xm'][i, k])
                                    for k in range(num_slots)
                                ],
                                [tensor2im(output_dict['monet/x_input'][i])] +
                                [
                                    tensor2im(output_dict['monet/x_tilde'][i])
                                    for k in range(num_slots - 1)
                                ]
                            ]
                            montage = montage_fig(monet_fig)
                        vis.row(id=i, image=fig, mask=montage)
                        plt.close()

                    with vis.table('Visualize #{} Trace'.format(visualized), [
                            HTMLTableColumnDesc('id', 'Step', 'text',
                                                {'width': '50px'}),
                            HTMLTableColumnDesc('image', 'Image', 'figure',
                                                {'width': '600px'}),
                            HTMLTableColumnDesc('p', 'operation', 'text',
                                                {'width': '200px'}),
                            HTMLTableColumnDesc('r', 'result', 'code',
                                                {'width': '200px'})
                    ]):
                        # TODO(Jiayuan Mao @ 11/20): support output_dict.programs.
                        for j, (prog, buf) in enumerate(
                                zip(feed_dict.program_seq[i],
                                    output_dict.buffers[i])):
                            if j != len(feed_dict.program_seq[i]) - 1 and (
                                    buf > 0
                            ).long().sum().item() > 0 and buf.size(
                                    0) == feed_dict.objects_raw[i].shape[0]:
                                this_objects = feed_dict.objects_raw[i][
                                    torch.nonzero(buf > 0)[:, 0].numpy()]
                                fig, ax = vis_bboxes(image,
                                                     this_objects,
                                                     'object',
                                                     add_text=False)
                            else:
                                fig, ax = vis_bboxes(image, [],
                                                     'object',
                                                     add_text=False)
                            vis.row(id=j, image=fig, p=repr(prog), r=repr(buf))
                            plt.close()

                    visualized += 1
                    if visualized > args.nr_visualize:
                        raise StopIteration()

                pbar.set_description(
                    meters.format_simple(
                        'Epoch {} (validation)'.format(epoch), {
                            k: v
                            for k, v in meters.val.items()
                            if k.startswith('validation') and k.count('/') <= 1
                        },
                        compressed=True))
                pbar.update()

                end = time.time()
    except StopIteration:
        pass

    from jacinle.utils.meta import dict_deep_kv
    from jacinle.utils.printing import kvformat
    with vis.table('Info', [
            HTMLTableColumnDesc('name', 'Name', 'code', {}),
            HTMLTableColumnDesc('info', 'KV', 'code', {})
    ]):
        vis.row(name='args', info=kvformat(args.__dict__, max_key_len=32))
        vis.row(name='configs',
                info=kvformat(dict(dict_deep_kv(configs)), max_key_len=32))
    vis.end_html()

    logger.info(
        'Happy Holiday! You can find your result at "http://monday.csail.mit.edu/xiuming'
        + osp.realpath(args.vis_dir) + '".')
Example #21
    def step(self,
             feed_dict,
             reduce_func=default_reduce_func,
             cast_tensor=False):
        assert self._model.training, 'Step an evaluation-mode model.'
        self.num_iters += 1
        self.trigger_event('step:before', self)
        loss_latent = 0.0
        if cast_tensor:
            feed_dict = as_tensor(feed_dict)

        begin = time.time()

        self.trigger_event('forward:before', self, feed_dict)

        rl_loss = 0.0
        if self.mode == 'warmup':
            loss, monitors, output_dict = self._model(feed_dict)
        else:
            if args.no_static:
                loss, monitors, output_dict = self._model(
                    feed_dict, return_loss_matrix=True)
                y_hat = output_dict['pred'].detach()
            else:
                with torch.no_grad():
                    #y_hat = self._static_model(feed_dict)['pred'].detach()
                    static_model_output = self._static_model(
                        feed_dict, return_loss_matrix=True)
                    if isinstance(static_model_output, dict):
                        y_hat = static_model_output['pred'].detach()
                        output_dict = static_model_output
                    else:
                        y_hat = static_model_output[2]['pred'].detach()
                        output_dict = static_model_output[2]

            keys = [
                'mask', 'n', 'query', 'count', 'is_ambiguous', 'qid',
                'target_set', 'relations', 'gtlt'
            ]

            expanded_feed_dict = {}
            for key in keys:
                if key in feed_dict:
                    expanded_feed_dict[key] = expand_tensor(
                        feed_dict[key], feed_dict["count"])
            #
            # unravel the target set to obtain the different targets
            expanded_feed_dict["target"] = unravel_tensor(
                feed_dict["target_set"], feed_dict["count"])
            # copy intermediate y for each target
            y_hat = expand_tensor(y_hat, feed_dict["count"])

            # inserting detached loss in the expanded_feed_dict for deterministic latent model
            #Pdb().set_trace()
            if 'loss_matrix' in output_dict:
                expanded_feed_dict['loss'] = unravel_tensor(
                    output_dict['loss_matrix'], feed_dict['count']).detach()
                if args.latent_model == 'eg':
                    expanded_feed_dict[
                        'minloss_eg_prob'] = unravel_minloss_epsilon_greedy(
                            output_dict['loss_matrix'], feed_dict['count'],
                            args.minloss_eg_eps).detach()
            # compute latent variable, i.e. the scores for each of the possible targets
            z_latent = self._latent_model(expanded_feed_dict, y_hat,
                                          output_dict)['latent_z']

            # start index and end index are markers for start and end indices
            # of each query in the expanded feed dict
            start_index = torch.cumsum(feed_dict["count"],
                                       0) - feed_dict["count"]
            end_index = torch.cumsum(feed_dict["count"], 0)

            min_indices = []
            action_prob = []
            #rl_weights = []
            weights = []

            # loop over each query
            for s, e in zip(start_index, end_index):
                dis2 = z_latent[s:e].squeeze(1)
                probs = get_prob_from_dis(dis2)
                weights.append(
                    F.pad(probs,
                          (0, feed_dict['target_set'].size(1) - probs.size(0)),
                          "constant", 0))
            #
            selected_feed_dict = feed_dict
            if args.rl_exploration:
                selected_feed_dict["weights"] = rl_sampling(
                    torch.stack(weights).detach().clone())
            else:
                selected_feed_dict["weights"] = torch.stack(
                    weights).detach().clone()

            loss = 0
            if not args.no_static:
                # Pdb().set_trace()
                loss, monitors, output_dict = self._model(selected_feed_dict)
            else:
                loss = (output_dict['loss_matrix'] *
                        selected_feed_dict['weights']
                        ).sum() / selected_feed_dict['weights'].sum()

            if (feed_dict['is_ambiguous'].sum() > 0):
                if not args.rl_exploration:
                    avg_reward = (
                        (output_dict['reward'] *
                         (feed_dict['mask'].float())).sum(dim=1) /
                        (feed_dict['mask'].sum(dim=1).float())).unsqueeze(-1)
                    #avg_reward = (output_dict['reward']*(feed_dict['mask'].float())).sum()/(feed_dict['mask'].sum().float())
                    rewards = (output_dict['reward'] -
                               avg_reward) * (feed_dict['mask'].float())
                    rl_loss = -1.0 * (rewards * torch.stack(weights)).sum(
                    ) / feed_dict['is_ambiguous'].sum()
                else:
                    #use selected_feed_dict['weights']. rewards should be only for non zero samples.
                    #Also, now we use REINFORCE : maximize : reward*log(p_action)
                    rl_loss = -1.0 * (
                        (output_dict['reward'] + 0.5) *
                        selected_feed_dict['weights'] * torch.log(
                            torch.stack(weights) + 1.0 -
                            selected_feed_dict['weights'])
                    ).sum() / feed_dict['is_ambiguous'].sum().float()
            loss_latent = rl_loss

        self.trigger_event('forward:after', self, feed_dict, loss, monitors,
                           output_dict)

        loss = reduce_func('loss', loss)
        loss_f = as_float(loss)

        monitors = {k: reduce_func(k, v) for k, v in monitors.items()}
        if self.mode == 'hot':
            monitors['loss_latent'] = loss_latent
        monitors_f = as_float(monitors)

        self._optimizer.zero_grad()
        if self.mode in ['hot']:
            if torch.is_tensor(loss_latent):
                loss_latent = reduce_func('loss_latent', loss_latent)
            #
            self._latent_optimizer.zero_grad()

        self.trigger_event('backward:before', self, feed_dict, loss, monitors,
                           output_dict)

        if loss.requires_grad:
            loss.backward()

        if self.mode in ['hot']:
            if torch.is_tensor(loss_latent):
                loss_latent.backward()
                # print("Grad:",self._latent_model.digit_embed.weight.grad[2,:2],self._latent_model.atn_across_steps.grad)
                # Pdb().set_trace()
                #print('Latent: ',self.digit_embed.weight.data[2,:4], self.row_embed.weight.data[2,:4])
                #print('Atn over steps: ',self.atn_across_steps)

        self.trigger_event('backward:after', self, feed_dict, loss, monitors,
                           output_dict)

        loss_latent_f = loss_latent.item() if torch.is_tensor(
            loss_latent) else loss_latent
        grad_norm_before_clip, grad_norm_after_clip, param_norm_before_clip = 0, 0, 0
        lgrad_norm_before_clip, lgrad_norm_after_clip, lparam_norm_before_clip = -1, -1, 0

        if loss.requires_grad:
            grad_norm_before_clip, grad_norm_after_clip, param_norm_before_clip = utils.gradient_normalization(
                self._model, grad_norm=args.grad_clip)
            #glogger.info(','.join(map(lambda x: str(round(x,6)),[self.current_epoch, self.num_iters, loss_f, loss_latent_f, grad_norm_before_clip.item(), grad_norm_after_clip.item(), param_norm_before_clip.item()])))
            if grad_norm_before_clip <= args.upper_limit_on_grad_norm:
                self._optimizer.step()
            else:
                self.num_bad_updates += 1
                logger.info(
                    'not taking optim step. Grad too high {}. Num bad updates: {}'
                    .format(round(grad_norm_before_clip, 2),
                            self.num_bad_updates))

            #self._optimizer.step()

        if self.mode in ['hot']:
            lgrad_norm_before_clip, lgrad_norm_after_clip, lparam_norm_before_clip = utils.gradient_normalization(
                self._latent_model, grad_norm=args.grad_clip)
            self._latent_optimizer.step()

        glogger.info(','.join(
            map(lambda x: str(round(x, 6)), [
                self.current_epoch, self.num_iters, loss_f, loss_latent_f,
                grad_norm_before_clip, grad_norm_after_clip,
                param_norm_before_clip, lgrad_norm_before_clip,
                lgrad_norm_after_clip, lparam_norm_before_clip
            ])))
        end = time.time()

        self.trigger_event('step:after', self)

        return loss_f, monitors_f, output_dict, {'time/gpu': end - begin}
Example #22
def _rms(p):
    return as_float((p**2).mean()**0.5)
Example #23
def validate_attribute(model,
                       val_dataloader,
                       meters,
                       meter_prefix='validation',
                       logger=None,
                       output_attr_path=''):
    end = time.time()
    video_num = len(val_dataloader)
    #pdb.set_trace()
    with tqdm_pbar(total=int(len(val_dataloader) * args.batch_size /
                             128)) as pbar:
        output_dict_list = []
        frame_id_list = []
        for feed_dict_list in val_dataloader:
            #for vid in range(video_num):
            end_frm_flag = False
            #while (not end_frm_flag):
            for idx, feed_dict in enumerate(feed_dict_list):
                scene_idx = feed_dict['meta_ann']['scene_index']
                full_path = os.path.join(
                    output_attr_path,
                    'attribute_' + str(scene_idx).zfill(5) + '.json')
                if os.path.isfile(full_path):
                    print('File exists. %s\n' % (full_path))
                    tmp_dict = jsonload(full_path)
                    if len(tmp_dict) == len(
                            feed_dict['tube_info']['box_seq']['tubes'][0]):
                        continue
                    print('size didn\'t match. %s\n' % (full_path))
                    #pdb.set_trace()
                if args.use_gpu:
                    if not args.gpu_parallel:
                        feed_dict = async_copy_to(feed_dict, 0)
                frm_id = feed_dict['frm_id']
                data_time = time.time() - end
                end = time.time()

                f_scene = model.resnet(feed_dict['img'])
                f_sng = model.scene_graph(f_scene, feed_dict)
                output_dict = parse_scene(feed_dict, f_sng,
                                          model.reasoning.embedding_attribute,
                                          frm_id)
                #pdb.set_trace()
                output_dict_list.append(output_dict)
                frame_id_list.append(frm_id)

                step_time = time.time() - end
                end = time.time()
                if frm_id == len(
                        feed_dict['tube_info']['box_seq']['tubes'][0]) - 1:
                    video_attr_list = []
                    for idx, result_dict in enumerate(output_dict_list):
                        mon_dict = result_dict.pop('monitors')
                        result_dict['frm_id'] = frame_id_list[idx]
                        video_attr_list.append(result_dict)
                        monitors = {
                            meter_prefix + '/' + k: v
                            for k, v in as_float(mon_dict).items()
                        }

                        n = 1
                        meters.update(monitors, n=n)
                        meters.update({
                            'time/data': data_time,
                            'time/step': step_time
                        })

                    jsondump(full_path, video_attr_list)

                    if args.use_tb:
                        meters.flush()

                    pbar.set_description(
                        meters.format_simple('({})'.format(args.setname), {
                            k: v
                            for k, v in meters.val.items()
                            if k.startswith('validation') and k.count('/') <= 2
                        },
                                             compressed=True))
                    pbar.update()

                    end = time.time()
                    output_dict_list = []
                    frame_id_list = []
                    if logger is not None:
                        logger.critical(
                            meters.format_simple(meter_prefix, {
                                k: v
                                for k, v in meters.avg.items() if v != 0
                            },
                                                 compressed=False))