def __init__(self):
    """Build the feature extractor, predictor and losses for path/sort tasks.

    Reads the module-global ``args`` namespace. Selects a MemoryNet or a
    LogicMachine feature extractor depending on ``args.model``.
    """
    super().__init__()
    # Predictions are made over unary (per-node) features for the path task,
    # and over binary (per-pair) features otherwise.
    self.feature_axis = 1 if args.is_path_task else 2
    if args.model == 'memnet':
        current_dim = 4 if args.is_path_task else 6
        self.feature = MemoryNet.from_args(current_dim, self.feature_axis,
                                           args, prefix='memnet')
        current_dim = self.feature.get_output_dim()
    else:
        input_dims = [0 for _ in range(args.nlm_breadth + 1)]
        if args.is_path_task:
            input_dims[1] = 2
            input_dims[2] = 2
        elif args.is_sort_task:
            input_dims[2] = 6
        self.features = LogicMachine.from_args(input_dims,
                                               args.nlm_attributes, args,
                                               prefix='nlm')
        if args.is_path_task:
            current_dim = self.features.output_dims[1]
        elif args.is_sort_task:
            # Was `args.task == 'sort'`; unified with the `args.is_sort_task`
            # check used when building `input_dims` above.
            current_dim = self.features.output_dims[2]
        else:
            # Fail fast instead of hitting a NameError on `current_dim` below.
            raise ValueError('task must be a path task or a sort task')
    self.pred = LogitsInference(current_dim, 1, [])
    self.loss = REINFORCELoss()
    self.pred_loss = nn.BCELoss()
def __init__(self):
    """Build the feature extractor, predictor and loss for graph/family tasks.

    Reads the module-global ``args`` namespace. Supports three feature
    extractors selected by ``args.model``: 'nlm', 'dlm' and 'memnet'.
    """
    super().__init__()
    # inputs
    input_dim = 4 if args.task_is_family_tree else 1
    self.feature_axis = 1 if args.task_is_1d_output else 2

    # features
    if args.model in ('nlm', 'dlm'):
        # Both logic-machine variants use the identical input-dims layout;
        # build it once instead of duplicating the block per model.
        input_dims = [0 for _ in range(args.nlm_breadth + 1)]
        if args.task_is_adjacent:
            input_dims[1] = args.gen_graph_colors
        if args.task_is_mnist_input:
            self.lenet = LeNet()
            input_dims[2] = input_dim
        if args.model == 'nlm':
            self.features = LogicMachine.from_args(
                input_dims, args.nlm_attributes, args, prefix='nlm')
        else:
            self.features = DifferentiableLogicMachine.from_args(
                input_dims, args.nlm_attributes, args, prefix='nlm')
            # Stochasticity knobs used only by the differentiable variant.
            self.tau = 1.0
            self.gumbel_prob = 1.0
            self.dropout_prob = 0.1
        output_dim = self.features.output_dims[self.feature_axis]
    elif args.model == 'memnet':
        if args.task_is_adjacent:
            input_dim += args.gen_graph_colors
        self.feature = MemoryNet.from_args(input_dim, self.feature_axis,
                                           args, prefix='memnet')
        output_dim = self.feature.get_output_dim()

    # target
    target_dim = args.adjacent_pred_colors if args.task_is_adjacent else 1
    if args.model == 'dlm':
        self.pred = DLMInferenceBase(output_dim, target_dim, False, 'root')
    else:
        self.pred = LogicInference(output_dim, target_dim, [])

    # losses
    if args.ohem_size > 0:
        # Online hard-example mining: keep per-element losses unreduced.
        from jactorch.nn.losses import BinaryCrossEntropyLossWithProbs as BCELoss
        self.loss = BCELoss(average='none')
    else:
        self.loss = nn.BCELoss()
def __init__(self):
    """Build the feature extractor and prediction heads for the blocks task.

    Reads the module-global ``args`` namespace; uses a MemoryNet when
    ``args.model == 'memnet'`` and a LogicMachine otherwise.
    """
    super().__init__()
    # Each object is described by 4 numbers: world_id, block_id, coord_x, coord_y.
    raw_dim = 4
    self.transform = InputTransform('cmp', exclude_self=False)
    # The pairwise 'cmp' transform triples the dimension: 4 * 3 = 12.
    feat_dim = transformed_dim = self.transform.get_output_dim(raw_dim)
    self.feature_axis = 1 if args.concat_worlds else 2

    if args.model == 'memnet':
        self.feature = MemoryNet.from_args(feat_dim, self.feature_axis,
                                           args, prefix='memnet')
        feat_dim = self.feature.get_output_dim()
    else:
        input_dims = [0] * (args.nlm_breadth + 1)
        input_dims[2] = feat_dim
        self.features = LogicMachine.from_args(input_dims,
                                               args.nlm_attributes, args,
                                               prefix='nlm')
        feat_dim = self.features.output_dims[self.feature_axis]

    self.final_transform = InputTransform('concat', exclude_self=False)
    if args.concat_worlds:
        feat_dim = (self.final_transform.get_output_dim(feat_dim)
                    + transformed_dim) * 2

    self.pred_valid = LogicInference(feat_dim, 1, [])
    self.pred = LogitsInference(feat_dim, 1, [])
    self.loss = REINFORCELoss()
    self.pred_loss = nn.BCELoss()
'exclude_self': True, 'logic_hidden_dim': [] }, prefix='nlm') nlm_group.add_argument( '--nlm-attributes', type=int, default=8, metavar='N', help= 'number of output attributes in each group of each layer of the LogicMachine' ) # MemNN parameters, works when model is 'memnet' memnet_group = parser.add_argument_group('Memory Networks') MemoryNet.make_memnet_parser(memnet_group, {}, prefix='memnet') # task related task_group = parser.add_argument_group('Task') task_group.add_argument('--task', required=True, choices=TASKS, help='tasks choices') task_group.add_argument('--train-number', type=int, default=10, metavar='N', help='size of training instances') task_group.add_argument('--adjacent-pred-colors', type=int, default=4,
def __init__(self):
    """Build feature extractor, prediction heads and losses for RL tasks.

    Reads the module-global ``args`` namespace. ``args.task`` fixes the
    input predicate layout; ``args.model`` selects memnet / nlm / dlm.

    Raises:
        ValueError: if ``args.task`` or ``args.model`` is not recognized.
    """
    super().__init__()
    self.transform = InputTransform('cmp', exclude_self=False)
    input_dims = None

    # Task-specific input layout.
    if args.task == 'final':
        # The 4 dimensions are: world_id, block_id, coord_x, coord_y.
        input_dim = 4
        # current_dim = 4 * 3 = 12
        current_dim = transformed_dim = self.transform.get_output_dim(input_dim)
        self.feature_axis = 1 if args.concat_worlds else 2
    elif args.task == 'stack':
        input_dim = 2
        current_dim = transformed_dim = self.transform.get_output_dim(input_dim)
        self.feature_axis = 2
    elif args.task == 'sort':
        self.feature_axis = 2
        current_dim = transformed_dim = 6
    elif args.task == 'path':
        self.feature_axis = 1
        input_dims = [0 for _ in range(args.nlm_breadth + 1)]
        input_dims[1] = 2
        input_dims[2] = 2
        transformed_dim = [0, 2, 2]
    elif 'nlrl' in args.task:
        self.feature_axis = 2
        input_dims = [0 for _ in range(args.nlm_breadth + 1)]
        input_dims[1] = 2  # unary: isFloor & top
        if args.task == 'nlrl-On':
            input_dims[2] = 2  # binary: goal_on & on
            transformed_dim = [0, 2, 2]
        else:
            input_dims[2] = 1  # binary: on
            transformed_dim = [0, 2, 1]
    else:
        # Was `raise ()`, which itself fails with an unhelpful
        # "exceptions must derive from BaseException" TypeError.
        raise ValueError('Unknown task: {}'.format(args.task))

    if args.model in ['dlm', 'nlm'] and input_dims is None:
        input_dims = [0 for _ in range(args.nlm_breadth + 1)]
        input_dims[2] = current_dim

    # Feature extractor.
    if args.model == 'memnet':
        self.feature = MemoryNet.from_args(current_dim, self.feature_axis,
                                           args, prefix='memnet')
        current_dim = self.feature.get_output_dim()
    elif args.model == 'nlm':
        self.features = LogicMachine.from_args(input_dims,
                                               args.nlm_attributes, args,
                                               prefix='nlm')
        current_dim = self.features.output_dims[self.feature_axis]
    elif args.model == 'dlm':
        self.features = DifferentiableLogicMachine.from_args(
            input_dims, args.nlm_attributes, args, prefix='nlm')
        current_dim = self.features.output_dims[self.feature_axis]
    else:
        # Was `raise ()` — see the note on the task check above.
        raise ValueError('Unknown model: {}'.format(args.model))

    self.final_transform = InputTransform('concat', exclude_self=False)
    if args.task == 'final':
        if args.concat_worlds:
            current_dim = (self.final_transform.get_output_dim(current_dim)
                           + transformed_dim) * 2

    # Prediction heads.
    if args.model == 'dlm':
        self.pred_valid = DLMInferenceBase(current_dim, 1, False, 'root_valid')
        self.pred = DLMInferenceBase(current_dim, 1, False, 'root')
        if args.distribution == 2:
            self.ac_selector = ActionSelector(current_dim)
        # Annealed stochasticity parameters for the differentiable machine.
        self.tau = args.tau_begin
        self.dropout_prob = args.dropout_prob_begin
        self.gumbel_prob = args.gumbel_noise_begin
        self.update_stoch()
    else:  # args.model == 'nlm'
        self.pred_valid = LogicInference(current_dim, 1, [])
        self.pred = LogitsInference(current_dim, 1, [])

    if args.reinforce_log:
        self.loss = REINFORCELogLoss()
    else:
        self.loss = REINFORCELoss()
    self.pred_loss = nn.BCELoss()
    self.force_decay = False
    self.rnorm = RunningMeanStd(shape=1)