Beispiel #1
0
    def forward(self, feed_dict):
        """Compute the action policy for the path or sort task.

        Args:
            feed_dict: batch mapping (wrapped in ``GView``). Reads ``states``
                and, for the path task, ``relations``. In training mode it
                also reads ``actions``, ``discount_rewards`` and
                ``entropy_beta``.

        Returns:
            In training mode, ``(loss, monitors, dict())`` with ``loss`` and
            ``monitors`` produced by ``self.loss``; otherwise
            ``dict(policy=policy, logits=logits)``.

        Raises:
            ValueError: if neither ``args.is_path_task`` nor
                ``args.is_sort_task`` is set, since there is then no
                relation input to read.
        """
        feed_dict = GView(feed_dict)
        states = None
        if args.is_path_task:
            states = feed_dict.states.float()
            relations = feed_dict.relations.float()
        elif args.is_sort_task:
            # For the sort task the relation input is read from ``states``;
            # there is no separate unary input.
            relations = feed_dict.states.float()
        else:
            raise ValueError(
                'forward() requires args.is_path_task or args.is_sort_task')

        if args.model == 'memnet':
            f = self.feature(relations, states)
        else:
            # Only indices 1 (states) and 2 (relations) of the NLM input
            # list are filled; the rest stay None.
            inp = [None for _ in range(args.nlm_breadth + 1)]
            inp[1] = states
            inp[2] = relations
            f = self.features(inp, depth=None)[self.feature_axis]
        if self.feature_axis == 2:  # sorting task
            f = meshgrid_exclude_self(f)

        # Flatten the per-element predictions; relations.size(0) is used as
        # the batch size.
        logits = self.pred(f).squeeze(dim=-1).view(relations.size(0), -1)
        # Set minimal value to avoid loss to be nan.
        policy = F.softmax(logits, dim=-1).clamp(min=1e-20)

        if self.training:
            loss, monitors = self.loss(policy, feed_dict.actions,
                                       feed_dict.discount_rewards,
                                       feed_dict.entropy_beta)
            return loss, monitors, dict()
        return dict(policy=policy, logits=logits)
Beispiel #2
0
    def get_binary_relations(self, states, depth=None):
        """Get binary relations given states, up to a certain depth.

        Args:
            states: input whose dim 1 holds the objects of both worlds,
                i.e. 2 * the number of objects in each world.
            depth: forwarded to ``self.features`` (non-memnet models only).

        Returns:
            Relation features after ``meshgrid_exclude_self``.
        """
        nr_total = states.size()[1]
        hidden = self.transform(states)
        if args.model != 'memnet':
            # Only index 2 of the NLM input list is filled.
            nlm_inputs = [None] * (args.nlm_breadth + 1)
            nlm_inputs[2] = hidden
            hidden = self.features(nlm_inputs, depth=depth)[self.feature_axis]
        else:
            hidden = self.feature(hidden)

        assert nr_total % 2 == 0
        per_world = nr_total // 2
        if not args.concat_worlds:
            # Keep only the entries indexed by the first world's objects.
            hidden = hidden[:, :per_world, :per_world].contiguous()
        else:
            # Concatenate the properties of blocks with the same id in both worlds.
            first_half = hidden[:, :per_world]
            second_half = hidden[:, per_world:]
            hidden = torch.cat([first_half, second_half], dim=-1)
            paired_states = torch.cat(
                [states[:, :per_world], states[:, per_world:]], dim=-1)
            # transform() runs before final_transform(), matching the call order.
            transformed_input = self.transform(paired_states)
            # 'concat' transform to the binary representation (relations).
            hidden = torch.cat(
                [self.final_transform(hidden), transformed_input], dim=-1)

        return meshgrid_exclude_self(hidden)
Beispiel #3
0
    def forward(self, feed_dict):
        """Predict graph-task targets for one batch.

        Args:
            feed_dict: batch mapping (wrapped in ``GView``). Reads
                ``relations``, ``states`` (adjacent task only) and, in training
                mode, ``target``.

        Returns:
            In training mode, ``(loss, monitors, dict(pred=pred))``; otherwise
            ``dict(pred=pred)``.

        Raises:
            NotImplementedError: memnet with the MNIST-input adjacent task.
            ValueError: if ``args.model`` is neither 'nlm' nor 'memnet'.
        """
        feed_dict = GView(feed_dict)

        # properties: only the adjacent task provides unary states.
        states = feed_dict.states.float() if args.task_is_adjacent else None

        # relations; dim 1 is used as the number of nodes ``nr``.
        relations = feed_dict.relations.float()
        _, nr = relations.size()[:2]

        mnist_adjacent = args.task_is_adjacent and args.task_is_mnist_input

        if args.model == 'nlm':
            if mnist_adjacent:
                # Run LeNet on every image: merge the first two dims into one
                # batch axis, encode, then restore them.
                states_shape = states.size()
                states = states.view((-1, ) + states_shape[2:])
                states = self.lenet(states)
                states = states.view(states_shape[:2] + (-1, ))
                # Tensor.sigmoid() replaces the deprecated F.sigmoid.
                states = states.sigmoid()

            inp = [None for _ in range(args.nlm_breadth + 1)]
            inp[1] = states
            inp[2] = relations

            depth = None
            if args.nlm_recursion:
                # Smallest d >= 1 with 2**d + 1 >= nr, then 2 * d + 1.
                depth = 1
                while 2**depth + 1 < nr:
                    depth += 1
                depth = depth * 2 + 1
            feature = self.features(inp, depth=depth)[self.feature_axis]
        elif args.model == 'memnet':
            # Check before calling self.feature so the unsupported combination
            # is reported even if the memnet cannot consume raw image states.
            if mnist_adjacent:
                raise NotImplementedError()
            feature = self.feature(relations, states)
        else:
            raise ValueError('Unsupported model: {}'.format(args.model))

        pred = self.pred(feature)
        if not args.task_is_adjacent:
            pred = pred.squeeze(-1)
        if args.task_is_connectivity:
            pred = meshgrid_exclude_self(pred)  # exclude self-cycle

        if not self.training:
            return dict(pred=pred)

        monitors = dict()
        target = feed_dict.target.float()

        if args.task_is_adjacent:
            target = target[:, :, :args.adjacent_pred_colors]

        monitors.update(binary_accuracy(target, pred, return_float=False))

        loss = self.loss(pred, target)
        # ohem loss is unused.
        if args.ohem_size > 0:
            loss = loss.view(-1).topk(args.ohem_size)[0].mean()
        return loss, monitors, dict(pred=pred)
Beispiel #4
0
    def get_binary_relations(self, states, depth=None):
        """Get binary relations given states, up to a certain depth.

        Args:
            states: model input. For the 'final' and 'stack' tasks it is
                passed through ``self.transform``; otherwise it is used as-is
                and may be a list, in which case element 0 fills NLM input
                slot 1 and element 1 fills slot 2.
            depth: forwarded to ``self.features`` (non-memnet models only).

        Returns:
            ``(f, more_info)``. ``more_info`` is the second output of
            ``self.features`` for the 'dlm' model and ``None`` otherwise.

        Raises:
            ValueError: if ``args.task`` is not a supported task.
        """
        more_info = None
        if args.task in ['final', 'stack']:
            # Object count along dim 1. For 'final' this is
            # 2 * the number of objects in each world; for 'stack' it is used
            # directly as the object count.
            total = states.size()[1]
            f = self.transform(states)
        else:
            f = states

        if args.model == 'memnet':
            f = self.feature(f)
        else:
            inp = [None for _ in range(args.nlm_breadth + 1)]
            if isinstance(f, list):
                inp[1] = f[0]
                inp[2] = f[1]
            else:
                inp[2] = f
            if args.model == 'dlm':
                if not self.training and args.extract_path:
                    self.features.extract_graph(self.feature_axis, self.pred)
                    # Inputs are cast to bool for path extraction.
                    inp = [x if x is None else x.bool() for x in inp]
                features = self.features(inp, depth=depth, extract_rule=args.extract_rule)
                f = features[0][self.feature_axis]
                more_info = features[1]
            else:
                features = self.features(inp, depth=depth)
                f = features[self.feature_axis]

        if args.task == 'final':
            assert total % 2 == 0
            nr_objects = total // 2
            if args.concat_worlds:
                # To concat the properties of blocks with the same id in both world.
                f = torch.cat([f[:, :nr_objects], f[:, nr_objects:]], dim=-1)
                states = torch.cat([states[:, :nr_objects], states[:, nr_objects:]], dim=-1)
                transformed_input = self.transform(states)
                # And perform a 'concat' transform to binary representation (relations).
                f = torch.cat([self.final_transform(f), transformed_input], dim=-1)
            else:
                f = f[:, :nr_objects, :nr_objects].contiguous()
        elif args.task == 'stack' or 'nlrl' in args.task:
            # Only the selected branch is evaluated, so ``total`` (which is
            # bound only for 'final'/'stack') is never read for nlrl tasks.
            nr_objects = total if args.task == 'stack' else states[0].size()[1]
            f = f[:, :nr_objects, :nr_objects].contiguous()
        elif args.task in ['sort', 'path']:
            pass
        else:
            # `raise ()` raised a tuple, which itself fails with a TypeError;
            # report the actual problem instead.
            raise ValueError('Unsupported task: {}'.format(args.task))

        if args.task != 'path':
            f = meshgrid_exclude_self(f)
        return f, more_info
Beispiel #5
0
    def forward(self, feed_dict):
        """Predict graph-task targets for one batch with NLM, DLM or MemNet.

        Args:
            feed_dict: batch mapping (wrapped in ``GView``). Reads
                ``relations``, ``states`` (adjacent task only) and, in training
                mode, ``target``.

        Returns:
            In training mode, ``(loss, monitors, dict(pred=pred))``; otherwise
            ``dict(pred=pred)``. For 'dlm' in training mode the returned
            ``pred`` is the rescaled value fed to the loss.

        Raises:
            NotImplementedError: memnet with the MNIST-input adjacent task.
            ValueError: if ``args.model`` is not 'nlm', 'dlm' or 'memnet'.
        """
        feed_dict = GView(feed_dict)

        # properties: only the adjacent task provides unary states.
        states = feed_dict.states.float() if args.task_is_adjacent else None

        # relations; dim 1 is used as the number of nodes ``nr``.
        relations = feed_dict.relations.float()
        _, nr = relations.size()[:2]

        mnist_adjacent = args.task_is_adjacent and args.task_is_mnist_input

        # Auxiliary outputs collected from the DLM layers (e.g. 'saturation',
        # 'entropies'), accumulated via update_dict_list.
        other_outputs = {}
        if args.model in ('nlm', 'dlm'):
            if mnist_adjacent:
                # Run LeNet on every image: merge the first two dims into one
                # batch axis, encode, then restore them.
                states_shape = states.size()
                states = states.view((-1, ) + states_shape[2:])
                states = self.lenet(states)
                states = states.view(states_shape[:2] + (-1, ))
                # Tensor.sigmoid() replaces the deprecated F.sigmoid.
                states = states.sigmoid()

            inp = [None for _ in range(args.nlm_breadth + 1)]
            inp[1] = states
            inp[2] = relations

            depth = None
            if args.nlm_recursion:
                # Smallest d >= 1 with 2**d + 1 >= nr, then 2 * d + 1.
                depth = 1
                while 2**depth + 1 < nr:
                    depth += 1
                depth = depth * 2 + 1
            if args.model == 'dlm':
                # extract path here for self.pred
                # NOTE(review): unlike get_binary_relations, this also runs in
                # training mode — confirm that is intended.
                if args.extract_path:
                    self.features.extract_graph(self.feature_axis, self.pred)
                    inp = [x if x is None else x.bool() for x in inp]
                feature, other_output = self.features(
                    inp, depth=depth, extract_rule=args.extract_rule)
                feature = feature[self.feature_axis]
                update_dict_list(other_outputs, other_output)
            else:
                feature = self.features(inp, depth=depth)[self.feature_axis]
        elif args.model == 'memnet':
            # Check before calling self.feature so the unsupported combination
            # is reported even if the memnet cannot consume raw image states.
            if mnist_adjacent:
                raise NotImplementedError()
            feature = self.feature(relations, states)
        else:
            raise ValueError('Unsupported model: {}'.format(args.model))

        if args.model == 'dlm':
            if args.extract_rule:
                print("last layer")
            pred, other_output = self.pred(feature)
            if args.extract_rule:
                print(self.pred.weight.argmax(-1))
            update_dict_list(other_outputs, other_output)
        else:
            pred = self.pred(feature)

        if not args.task_is_adjacent:
            pred = pred.squeeze(-1)
        if args.task_is_connectivity:
            pred = meshgrid_exclude_self(pred)  # exclude self-cycle

        if not self.training:
            return dict(pred=pred)

        monitors = dict()
        target = feed_dict.target.float()

        if args.task_is_adjacent:
            target = target[:, :, :args.adjacent_pred_colors]

        monitors.update(binary_accuracy(target, pred, return_float=False))
        if args.model == 'dlm':
            # to stabilize BCELoss: map pred into [1e-5, 1 - 1e-5].
            pred = 1e-5 + pred * (1.0 - 2e-5)

            saturation = torch.cat(
                [a.flatten() for a in other_outputs['saturation']])
            monitors.update({'saturation': saturation})
            monitors.update({'tau': np.array(self.tau)})
            monitors.update({'gumbel_prob': np.array(self.gumbel_prob)})
            monitors.update({'dropout_prob': np.array(self.dropout_prob)})

        loss = self.loss(pred, target)
        if args.model == 'dlm' and args.entropy_reg != 0.0:
            entropies = torch.cat(
                [a.flatten() for a in other_outputs['entropies']])
            # Out-of-place add: modifying the loss output in place can break
            # autograd if the loss op saved its output for backward.
            loss = loss + args.entropy_reg * entropies.mean()

        # ohem loss is unused.
        if args.ohem_size > 0:
            loss = loss.view(-1).topk(args.ohem_size)[0].mean()
        return loss, monitors, dict(pred=pred)