def remote_variables(self):
   """Collect this module's trainable and moving-average variables.

   Returns:
     A list with every variable `snt.get_variables_in_module` reports for
     this module in the TRAINABLE_VARIABLES collection, followed by those in
     the MOVING_AVERAGE_VARIABLES collection.
   """
   gathered = []
   for collection_key in (tf.GraphKeys.TRAINABLE_VARIABLES,
                          tf.GraphKeys.MOVING_AVERAGE_VARIABLES):
     gathered.extend(snt.get_variables_in_module(self, collection_key))
   return gathered
Beispiel #2
0
    def get_trainable_vars(self):
        """
        Return the variables of this module that should be trained.

        When the config defines `fine_tune_from`, training starts at the
        first variable whose name contains that string and covers every
        variable after it. For example, `vgg_16/fc6` on a VGG16 leaves only
        the fully connected layers trainable. When `fine_tune_from` is None,
        every variable of the module is returned.

        Returns:
            trainable_variables: a tuple of `tf.Variable`.

        Raises:
            ValueError: if no variable name contains `fine_tune_from`.
        """
        module_vars = snt.get_variables_in_module(self)

        start_marker = self._config.get('fine_tune_from')
        if start_marker is None:
            return module_vars

        # Position of the first variable whose name mentions the marker;
        # None when no variable matches.
        first_match = next(
            (pos for pos, var in enumerate(module_vars)
             if start_marker in var.name),
            None
        )
        if first_match is None:
            raise ValueError(
                '"{}" is an invalid value of fine_tune_from for this '
                'architecture.'.format(start_marker)
            )

        return module_vars[first_match:]
Beispiel #3
0
 def w(self):
   """Return the module's single weight variable (raw name "w").

   Sanity-checks the module's variable count first: two variables are
   expected when `self.use_bias` is truthy, otherwise exactly one.
   """
   module_vars = snt.get_variables_in_module(self)
   count = len(module_vars)
   if self.use_bias:
     assert count == 2, "Found not 2 but %d" % count
   else:
     assert count == 1, "Found not 1 but %d" % count
   weights = [v for v in module_vars if self._raw_name(v.name) == "w"]
   assert len(weights) == 1
   return weights[0]
Beispiel #4
0
  def local_variables(self):
    """Return the variables that must be refreshed for each evaluation.

    These variables should not be stored on a parameter server and should
    be reset every computation of a meta_objective loss.

    Returns:
      vars: list of tf.Variable -- the module's variables that belong to the
        TRAINABLE_VARIABLES collection.
    """
    module_trainables = snt.get_variables_in_module(
        self, tf.GraphKeys.TRAINABLE_VARIABLES)
    return list(module_trainables)
Beispiel #5
0
  def testQueryInModule(self):
    linear = snt.Linear(output_size=42, name="linear")

    # Before the module is connected, asking for its variables must fail.
    with self.assertRaisesRegexp(snt.Error, "not instantiated yet"):
      linear.get_variables()

    # After connecting, both the method and the free function must report
    # exactly the weight and bias.
    inputs = tf.placeholder(tf.float32, shape=[3, 4])
    linear(inputs)
    expected = {linear.w, linear.b}
    self.assertEqual(set(linear.get_variables()), expected)
    self.assertEqual(set(snt.get_variables_in_module(linear)), expected)
Beispiel #6
0
def save(network, sess, filename=None):
    """Evaluate every variable of `network` and optionally pickle the values.

    Each variable name has its ":N" output suffix dropped and is split on
    "/". The value is then stored under
    result[<second-to-last component>][<last component>].

    NOTE(review): variables from different modules that share the same
    immediate scope name and leaf name overwrite each other here.

    Args:
        network: module whose variables come from
            `snt.get_variables_in_module`.
        sess: session passed to each variable's `eval`.
        filename: when truthy, the nested dict is pickled to this path.

    Returns:
        A `collections.defaultdict(dict)` of scope name -> {leaf name -> value}.
    """
    saved = collections.defaultdict(dict)
    for var in snt.get_variables_in_module(network):
        name_parts = var.name.split(":")[0].split("/")
        parent_scope, leaf_name = name_parts[-2], name_parts[-1]
        saved[parent_scope][leaf_name] = var.eval(sess)

    if filename:
        with open(filename, "wb") as out_file:
            pickle.dump(saved, out_file)

    return saved
Beispiel #7
0
    def load_weights(self):
        """
        Build an op that restores this module's variables from a checkpoint.

        Every MODEL_VARIABLES variable of the module is assumed to be present
        in the checkpoint under its own name minus this module's
        variable-scope prefix.

        If the config has no `weights` and `download` is falsy, nothing is
        loaded. If `weights` is missing but `download` is truthy, the
        checkpoint path is resolved with `get_checkpoint_file` and written
        back into the config.

        Returns:
            load_op: Load weights operation or no_op.
        """
        if self._config.get('weights') is None:
            if not self._config.get('download'):
                return tf.no_op(name='not_loading_base_network')
            # Download the weights (or use cached ones) because the config
            # does not name them. They are downloaded by default to the
            # ~/.luminoth folder when running locally, or to the job bucket
            # when running in Google Cloud.
            self._config['weights'] = get_checkpoint_file(self._architecture)

        module_variables = snt.get_variables_in_module(
            self, tf.GraphKeys.MODEL_VARIABLES
        )
        assert len(module_variables) > 0

        checkpoint_path = self._config['weights']
        # Drop the "<scope>/" prefix so names match the checkpoint's keys.
        scope_prefix_len = len(self.variable_scope.name) + 1
        assign_ops = [
            tf.assign(
                var,
                tf.contrib.framework.load_variable(
                    checkpoint_path, var.op.name[scope_prefix_len:]
                )
            )
            for var in module_variables
        ]

        tf.logging.info(
            'Constructing op to load {} variables from pretrained '
            'checkpoint {}'.format(len(assign_ops), checkpoint_path)
        )

        return tf.group(*assign_ops)
Beispiel #8
0
    def get_trainable_vars(self):
        """Get trainable vars included in the module.

        The module's own variables always come first. When
        `self._config.model.base_network.trainable` is truthy, the variables
        returned by `self.base_network.get_trainable_vars()` are appended.
        Which pretrained variables (if any) are trained is logged.
        """
        own_vars = snt.get_variables_in_module(self)
        if not self._config.model.base_network.trainable:
            tf.logging.info("Not training variables from pretrained module")
            return own_vars

        pretrained = self.base_network.get_trainable_vars()
        if pretrained:
            tf.logging.info(
                'Training {} vars from pretrained module; '
                'from "{}" to "{}".'.format(
                    len(pretrained),
                    pretrained[0].name,
                    pretrained[-1].name,
                )
            )
        else:
            tf.logging.info("No vars from pretrained module to train.")
        own_vars += pretrained
        return own_vars
    def get_trainable_vars(self):
        """Return the variables this model should optimize.

        This module's own variables always come first. If the config flag
        `model.base_network.trainable` is truthy, the variables reported by
        `self.base_network.get_trainable_vars()` are appended. The decision
        is logged either way.
        """
        result = snt.get_variables_in_module(self)
        train_base = self._config.model.base_network.trainable
        if train_base:
            base_vars = self.base_network.get_trainable_vars()
            if len(base_vars) == 0:
                tf.logging.info('No vars from pretrained module to train.')
            else:
                first_name = base_vars[0].name
                last_name = base_vars[-1].name
                tf.logging.info(
                    'Training {} vars from pretrained module; '
                    'from "{}" to "{}".'.format(
                        len(base_vars), first_name, last_name
                    )
                )
            result += base_vars
        else:
            tf.logging.info('Not training variables from pretrained module')
        return result
Beispiel #10
0
    def load_weights(self):
        """
        Build an op that restores this module's variables from a checkpoint.

        Every MODEL_VARIABLES variable of the module is assumed to exist in
        the checkpoint, named like the variable but without this module's
        variable-scope prefix.

        Returns:
            load_op: Load weights operation or no_op.
        """
        if self._config.get('weights') is None:
            if not self._config.get('download'):
                return tf.no_op(name='not_loading_base_network')
            # No weights given in the config: download them (or reuse a
            # cached copy). By default they go to the $LUMI_HOME folder when
            # running locally, or to the job bucket in Google Cloud.
            self._config['weights'] = get_checkpoint_file(self._architecture)

        module_variables = snt.get_variables_in_module(
            self, tf.GraphKeys.MODEL_VARIABLES)
        assert len(module_variables) > 0

        checkpoint_path = self._config['weights']
        # Length of "<scope>/" to strip from each variable's op name.
        prefix_len = len(self.variable_scope.name) + 1
        assign_ops = []
        for var in module_variables:
            checkpoint_name = var.op.name[prefix_len:]
            value = tf.contrib.framework.load_variable(
                checkpoint_path, checkpoint_name)
            assign_ops.append(tf.assign(var, value))

        tf.logging.info('Constructing op to load {} variables from pretrained '
                        'checkpoint {}'.format(len(assign_ops),
                                               checkpoint_path))

        return tf.group(*assign_ops)
 def get_trainable_vars(self):
     """Return every variable created in this module, unfiltered."""
     module_vars = snt.get_variables_in_module(self)
     return module_vars
Beispiel #12
0
  def meta_loss(self,
                make_loss,
                len_unroll,
                net_assignments=None,
                model_path = None,
                second_derivatives=False):
    """Returns an operator computing the meta-loss.

    Creates `self.num_lstm` stacked copies of every optimizee variable and
    unrolls `len_unroll` learned-optimizer steps in a `tf.while_loop`. Each
    step combines intra- and inter-attention over the copies' gradients.
    The recorded losses and first-variable values are then fed to
    `entropy_loss.self_loss`.

    Args:
      make_loss: Callable which returns the optimizee loss; note that this
          should create its ops in the default graph.
      len_unroll: Number of steps to unroll.
      net_assignments: variable to optimizer mapping. If not None, it should be
          a list of (k, names) tuples, where k is a valid key in the kwargs
          passed at construction time and names is a list of variable names.
      model_path: if not None, the intra-attention FC kernels and biases are
          loaded from './<model_path>/loss_record.pickle' (keys 'fc_weights'
          and 'fc_bias') instead of being created randomly.
      second_derivatives: Use second derivatives (default is false).

    Returns:
      MetaLoss(loss, update, reset, fx_final, x_final, constants).
      fx_final is a list with the mean final loss of each copy. x_final is
      the list of per-variable tensors returned by the while loop.

    Note:
      Reads self.num_lstm, self.intra_features, self._config and
      self.fx_minimal. Appends to self.fc_va (and, without model_path, to
      self.fc_kernel / self.fc_bias), so these attributes must already exist.
      Rebinds self._nets, self.x_minimal, self.fx_minimal, self.pre_deltas
      and self.pre_gradients (and self.fc_kernel / self.fc_bias when
      model_path is given). The print() calls are graph-construction debug
      output.
    """

    # Construct an instance of the problem only to grab the variables. This
    # loss will never be evaluated.
    
    sub_x, sub_constants=_get_variables(make_loss)
    print (sub_x, sub_constants)
    
    def intra_init(x):
      """Create the intra-attention FC parameters, one set per tensor in `x`.

      Without model_path: appends a random kernel, a random bias and a
      non-trainable all-ones `va` row to self.fc_kernel / self.fc_bias /
      self.fc_va. With model_path: replaces kernels and biases with those
      stored in the pickle file and appends only the `va` rows.
      """
      
      fc_columns = 10
      
      # NOTE(review): `is None` is the idiomatic test here.
      if model_path == None:
        for i in range(len(x)):
          fc_shape = tf.reshape(x[i],[1,-1]).get_shape()
          print(fc_shape)
          # The NumPy array is only used for its .shape. The kernel width is
          # twice the flattened variable size -- presumably matching the
          # [feature, hidden state] concatenation in intra_attention; TODO
          # confirm.
          fc_kernel = np.random.rand(fc_columns, fc_shape[1]*2)
          print(fc_kernel.shape)
          fc_bias_shape = [fc_columns, self.intra_features ]
          fc_va_shape = [1,fc_columns]
          sub_fc_kernel = tf.Variable(tf.random_normal(fc_kernel.shape))
          sub_fc_bias = tf.Variable(tf.random_normal(fc_bias_shape))
          sub_fc_va = tf.Variable(tf.ones(fc_va_shape), trainable = False)
          self.fc_kernel.append(sub_fc_kernel)
          self.fc_bias.append(sub_fc_bias)
          self.fc_va.append(sub_fc_va)
      else:
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # model_path at trusted files.
        with open('./{}/loss_record.pickle'.format(model_path),'rb') as loss:
          data = pickle.load(loss)
        self.fc_kernel = [tf.Variable(item) for item in data['fc_weights']]
        self.fc_bias = [tf.Variable(item) for item in data['fc_bias']]
        for i in range(len(x)):
          fc_va_shape = [1,fc_columns]
          sub_fc_va = tf.Variable(tf.ones(fc_va_shape), trainable = False)
          self.fc_va.append(sub_fc_va)
    
    intra_init(sub_x)

    def vars_init(x):
      """Return one non-trainable variable per tensor in `x`.

      Variable i has shape [self.num_lstm] + x[i].shape (presumably one slice
      per parallel copy) and is drawn from a normal distribution with mean 0
      and stddev 0.01.
      """
      variables = []
      for i in range(len(x)):
        vars_shape = x[i].get_shape()
        print(vars_shape)
        # Only the shape of this NumPy array is used. Note that `vars`
        # shadows the builtin.
        vars = np.random.rand(self.num_lstm,*vars_shape)
        print(vars.shape)
        sub_vars = tf.Variable(tf.random.normal(vars.shape, mean=0.0, stddev=0.01), trainable = False)
        variables.append(sub_vars)
      return variables
    
    # x: optimizee parameters, one stacked variable per original variable.
    # self.x_minimal: best point seen so far (replaced in time_step).
    x = vars_init(sub_x)
    constants = vars_init(sub_constants)
    self.x_minimal = vars_init(sub_x)
    print(sub_x, x[0].shape)
    print(constants)
    # Disabled alternative initializer, kept as an (unused) string literal.
    '''
    def intra_vars_init(x):
      variables = []
      for i in range(len(x)):
        sub_vars = []
        for j in range(self.num_lstm):
          vars_shape = tf.shape(x[i])
          var = tf.Variable(tf.random_normal(vars_shape, mean = 0, stddev = 0.01), trainable = False)
          sub_vars.append(var)
        sub_vars = tf.stack(sub_vars, axis = 0)
        variables.append(sub_vars)
      return variables
    '''
    # Previous deltas (used as the "hidden state" input of intra_attention)
    # and previous gradients (momentum term), same shapes as x.
    self.pre_deltas = vars_init(sub_x)
    self.pre_gradients = vars_init(sub_x)
    
    
    
    print("x.length",len(x), x[0].shape, type(x[0]))
    print("Optimizee variables")
    print([op.name for op in sub_x])
    print("Problem variables")
    print([op.name for op in sub_constants])

    # Create the optimizer networks and find the subsets of variables to assign
    # to each optimizer.
    nets, net_keys, subsets = _make_nets(sub_x, self._config, net_assignments) 

    # Store the networks so we can save them later.
    self._nets = nets

    # Create hidden state for each subset of variables.
    # (The triple-quoted blocks below are disabled code kept as unused
    # string literals.)
    '''
    if len(subsets) > 1:
      state = []
    else:
      state=[[] for i in range(len(subsets))]
    '''
    state = []
    with tf.name_scope("states"):
      for i, (subset, key) in enumerate(zip(subsets, net_keys)):
        net = nets[key]
        with tf.name_scope("state_{}".format(i)):
          state.append(_nested_variable(
              [net.initial_state_for_inputs(tf.stack(x[j], axis = 0), dtype=tf.float32)
               for j in subset],
              name="state", trainable=False))
      '''
      if len(subsets) > 1:
        state.append(single_state) 
      else:
        for i in range(len(subsets))
          state[i].append(single_state[i]) 
      '''
    
    
    def update(t, net, gradients, x, attraction, state):
      """Parameter and RNN state update.

      NOTE(review): `gradients` and `attraction` are the full per-variable
      lists built in time_step, while `x` holds only this subset's
      variables. The indexing below therefore assumes a single subset that
      covers all variables in order -- confirm.
      """
      
   
      def attraction_init(mat):
        """Combine displacement vectors into one attraction term per copy.

        For each copy i, reshapes mat[i] to (num_lstm, -1), weights its rows
        by softmax(-alpha * squared L2 norm) and sums them. The stacked
        results are reshaped to mat[0]'s shape.
        """
        alpha = 1
        print(mat)
        norm = []
        for i in range(self.num_lstm):
          print(mat[i])
          mat_norm = tf.reshape(mat[i], [self.num_lstm, -1])
          print(mat_norm)
          mat_l2 = tf.reduce_sum(tf.multiply(mat_norm, mat_norm), axis = 1)
          print(mat_l2)
          mat_l2 = -alpha*tf.reshape(mat_l2, [1, -1])
          print(mat_l2)
          mat_softmax = tf.nn.softmax(mat_l2)
          print(mat_softmax)
          norm.append(tf.matmul(mat_softmax, mat_norm))
        print(norm)
        attraction = tf.reshape(tf.stack(norm, axis = 0), tf.shape(mat[0]))
        print(attraction)
        return attraction
      
        
      print(attraction)
      x_attraction = [attraction_init(sub_attraction) for sub_attraction in attraction]
      print(x_attraction)
      def inter_attention(mat_a,mat_b):
        """Mix gradients across the parallel copies.

        mat_a: parameters reshaped to (num_lstm, -1).
        mat_b: gradients reshaped to (num_lstm, -1).
        Returns grad + gama * (A @ grad). A multiplies a softmax over
        gradient dot products by a softmax over the parameter-similarity
        scores from x_res_init.
        NOTE(review): assumes true division (Python 3). Under Python 2
        integer division, 1/self.num_lstm and -1/2 give 0 and -1.
        """
        l=1
        gama=1/self.num_lstm
        origin = mat_a
        print(origin)
        grad = mat_b 
        print(grad)
        def x_res_init(mat):
          """Return softmax of scores x_i . (x_i - x_k), scaled by (-1/2)*l."""
          mat_split=tf.split(mat, self.num_lstm)
          norm = []
          for i in range(self.num_lstm):
            sub_norm = tf.tile(mat_split[i], [self.num_lstm, 1])
            sub_norm = sub_norm - mat
            norm.append(tf.matmul(mat_split[i],tf.transpose(sub_norm)))
          # Parses as (-1/2)*l, which equals -1/(2*l) only because l == 1.
          result = (-1/2*l)*tf.concat(norm, axis = 0)
          print(result)
          normalise = tf.nn.softmax(tf.transpose(result))
          return normalise

        def attention(Mat_a,Mat_b):
          """Return Mat_b + gama * (Mat_a @ Mat_b)."""
          result = gama*tf.matmul(Mat_a,Mat_b)
          Mat_b = tf.add(Mat_b,result)
          return Mat_b
        matmul1 = tf.matmul(grad,tf.transpose(grad))
        print(matmul1)
        softmax_grad = tf.nn.softmax(tf.transpose(matmul1))
        print(softmax_grad)
        softmax_origin = x_res_init(origin)
        print(softmax_origin)
        input_mul_x = tf.matmul(softmax_grad,softmax_origin)
        e_ij = attention(input_mul_x,grad)
        print(e_ij)
        return e_ij
      
      def intra_attention( grad, pre_grad, x, x_min, sub_x_attraction, ht, fc_kernel, fc_bias, fc_va):
        """Attention over per-variable features; returns a tensor shaped like grad.

        Features are: gradient, beta-scaled previous gradient, distance to
        the best point so far, and the attraction term. Each is scored by a
        tanh FC layer on [feature, ht] followed by the `va` projection, and
        the softmax-weighted features are summed.
        NOTE(review): presumably self.intra_features == 4 (the number of
        concatenated tensors) -- confirm.
        """
        beta = 0.9
        intra_feature = tf.concat([grad, beta*pre_grad, x - x_min, sub_x_attraction], axis=0)
        intra_feature = tf.reshape(intra_feature,[self.num_lstm*self.intra_features, -1])
        print(intra_feature)
        ht_concat = tf.concat([ht for i in range(self.intra_features)], axis = 0)
        ht_concat = tf.reshape(ht_concat, [self.num_lstm*self.intra_features, -1])
        print(ht_concat)
        intra_concat = tf.concat([intra_feature, ht_concat], axis = 1)
        print(intra_concat)
        intra_fc_bias = tf.tile(fc_bias, [1, self.num_lstm])
        print(intra_fc_bias)
        intra_fc = tf.tanh(tf.matmul(fc_kernel,tf.transpose(intra_concat)) + intra_fc_bias)
        print(intra_fc)
        va = fc_va
        print(va)
        b_ij = tf.matmul(va,intra_fc)
        print(b_ij)
        # NOTE(review): rows of intra_feature are feature-major (all copies
        # of grad, then all copies of the momentum term, ...). This softmax,
        # however, groups consecutive scores into num_lstm groups. Verify it
        # normalises over the intended rows.
        p_ij = tf.nn.softmax(tf.reshape(b_ij, [self.num_lstm, -1]))
        print(p_ij)
        p_ij = tf.reshape(p_ij, [self.num_lstm*self.intra_features, 1])
        print(p_ij)
        gradient = tf.multiply(p_ij, intra_feature)
        # Sum the weighted features back into one gradient-shaped tensor.
        gradient = tf.reshape(gradient, [self.intra_features, -1])
        gradient = tf.reduce_sum(gradient, axis = 0)
        print(gradient)
        gradient = tf.reshape(gradient, tf.shape(grad))
        print(gradient)
        return gradient
      with tf.name_scope("gradients"):
        print(x)
            
        # Stopping the gradient here corresponds to what was done in the
        # original L2L NIPS submission. However it looks like things like
        # BatchNorm, etc. don't support second-derivatives so we still need
        # this term.
        # This rebinds `gradients` to a new list. When second_derivatives is
        # True, the loop below therefore mutates the caller's list in place.
        if not second_derivatives:
          gradients = [tf.stop_gradient(g) for g in gradients]
        print(gradients)

        
      print(gradients)
      with tf.name_scope("inter_attention"):
        for i in range(len(x)):
          # Flatten each variable and its gradient to (num_lstm, -1) so the
          # attention operates across the parallel copies.
          shape = tf.shape(x[i])
          mat_x = tf.reshape(x[i], [self.num_lstm,-1])
          
          mat_grads = tf.reshape(gradients[i], [self.num_lstm,-1])
          
          # Replace the raw gradient with its inter-attention mixture.
          inter_grads=inter_attention(mat_x, mat_grads)
          gradients[i] = tf.reshape(inter_grads, shape)

        
        print('mnist_gradients',gradients)


      with tf.name_scope("deltas"):
        
        x_min = self.x_minimal
        ht = self.pre_deltas
        pre_grads = self.pre_gradients
        # zip() truncates to the shortest input (this subset / its states).
        deltas, state_next = zip(*[net(intra_attention( grad, pre_grad, x, x_min, sub_x_attraction, 
        ht, fc_kernel, fc_bias , fc_va), s)
        for grad, pre_grad, x, x_min, sub_x_attraction, ht, fc_kernel, fc_bias , fc_va, s in zip(gradients,
        pre_grads, x, x_min, x_attraction, ht, self.fc_kernel, self.fc_bias , self.fc_va, state)])
        # NOTE(review): these Python attributes are rebound while the
        # tf.while_loop body is traced (once). They are not loop variables,
        # so the values do not carry across iterations. Afterwards they refer
        # to tensors inside the loop frame -- confirm this is intended.
        self.pre_deltas = deltas
        print(deltas)
        print(state_next)
        state_next = list(state_next)
        self.pre_gradients=gradients
        print(state_next)
      
      return deltas, state_next
    # In time_step, `x` is the current parameter set of the optimizee (MNIST)
    # network. `x_next` is that parameter set after the LSTM-driven gradient
    # update. Each step applies both intra- and inter-attention.
    def time_step(t, fx_array, x, x_array, state):
      """While loop body.

      Evaluates the loss of each of the num_lstm copies and their gradients,
      builds the attraction features, updates the running best point, and
      applies the learned optimizer's deltas.
      """
      x_next = list(x)
      state_next = []

      
      with tf.name_scope("fx"):
        update_fx = []
        gradients = []
        for z in range(self.num_lstm):
          # Slice copy z out of every stacked variable.
          sub_x = [item[z] for item in x]
          sub_fx_batch = _make_with_custom_variables(make_loss, sub_x)
          sub_fx = tf.reduce_mean(sub_fx_batch)

          # Slot t*num_lstm + z holds copy z at step t. Only the first
          # optimizee variable is recorded in x_array.
          x_array = x_array.write(t*self.num_lstm + z, sub_x[0])
          fx_array = fx_array.write(t*self.num_lstm + z, sub_fx_batch)
          sub_gradients = tf.gradients(sub_fx, sub_x)
          print(sub_gradients)
          update_fx.append(sub_fx)
          gradients.append(sub_gradients)
        
        print(gradients)
        # Transpose from per-copy lists to per-variable stacks of shape
        # (num_lstm, ...).
        gradients = zip(*gradients)
        gradients = [tf.stack(list(gradient), axis = 0)for gradient in gradients]
        print(gradients)
        # attraction[j][ind1] stacks, over ind2, either a zero tensor (the
        # x[j][ind1] - x[j][ind1] branch, used when copy ind1's loss is
        # greater than copy ind2's) or x[j][ind1] - x[j][ind2].
        # NOTE(review): only copies whose loss is not lower than ind1's
        # contribute -- confirm the intended direction.
        attraction = []
        for j in range(len(x)):
          attraction_x = []
          for ind1, item1 in enumerate(update_fx):
            sub_attraction_x = []
            for ind2, item2 in enumerate(update_fx):
              return_value = tf.where(tf.greater(item1, item2), x[j][ind1] - x[j][ind1], x[j][ind1] - x[j][ind2])
              print(return_value)
              sub_attraction_x.append(return_value)
              print(sub_attraction_x)
            attraction_x.append(tf.stack(sub_attraction_x, axis = 0))
            print(attraction_x)
          attraction.append(attraction_x)
        print(attraction)
        # Keep the previous best (fx, x) if the summed loss got worse;
        # otherwise record the current point. Same while-loop caveat as for
        # self.pre_deltas in update().
        fx_sum = tf.reduce_sum(tf.stack(update_fx))
        def f1(): return self.fx_minimal, self.x_minimal
        def f2(): return fx_sum , x
        self.fx_minimal, self.x_minimal = tf.cond(tf.greater(fx_sum, self.fx_minimal), lambda:f1(), lambda:f2())
      with tf.name_scope("dx"):
        for subset, key, s_i in zip(subsets, net_keys, state):
          
          x_i = [x[j] for j in subset]
          
          deltas, s_i_next = update(t, nets[key], gradients, x_i, attraction, s_i)

          ratio=1.
          for idx, j in enumerate(subset):
            x_next[j] += deltas[idx]*ratio
          state_next.append(s_i_next)

      with tf.name_scope("t_next"):
        t_next = t + 1

      
        
      return t_next, fx_array, x_next, x_array, state_next

    
    # Define the while loop.
    # One slot per (step, copy). The extra (len_unroll + 1)-th group of slots
    # is filled after the loop with the final losses.
    fx_array = tf.TensorArray(tf.float32, size=(len_unroll + 1)*self.num_lstm,
                              clear_after_read=False)

    # we need x_array for calculating the entropy loss
    x_array = tf.TensorArray(tf.float32, size=(len_unroll + 1)*self.num_lstm,
                              clear_after_read=False)

    _, fx_array, x_final, x_array, s_final = tf.while_loop(
        cond=lambda t, *_: t < len_unroll,
        body=time_step,
        loop_vars=(0, fx_array, x, x_array, state),
        parallel_iterations=1,
        swap_memory=True,
        name="unroll")

    with tf.name_scope("fx"):
      # Evaluate the loss of every copy at the final parameters.
      final_x = [tf.unstack(tensor, axis = 0) for tensor in x_final]
      print ("final_x", final_x[0][0].shape, x_final[0].shape)
      with tf.name_scope("fx"):
        fx_final = []
        for z in range(self.num_lstm):
          sub_x = [item[z] for item in final_x]
          print ("sub_x", type(sub_x[0]))
          sub_fx_final_batch = _make_with_custom_variables(make_loss, sub_x)
          sub_fx_final = tf.reduce_mean(sub_fx_final_batch)
          print ('sub_x[0]', sub_x[0].shape, sub_fx_final.shape)
          

          fx_array = fx_array.write(len_unroll*self.num_lstm + z, sub_fx_final_batch)
          x_array = x_array.write(len_unroll*self.num_lstm + z, sub_x[0])
          fx_final.append(sub_fx_final)
         


    print (x[0].shape, x_final[0].shape,  len(fx_final),'xinfal11', len_unroll)
    
    
    loss = entropy_loss.self_loss(x_array.stack(), fx_array.stack(), (len_unroll + 1)*self.num_lstm)
    

    # Reset the state; should be called at the beginning of an epoch.
    # NOTE(review): self.x_minimal, self.pre_deltas, self.pre_gradients and
    # the fc_* variables are not part of this initializer list.
    with tf.name_scope("reset"):
      variables = (nest.flatten(state) +
                   x + constants)
      # Empty array as part of the reset process.
      reset = [tf.variables_initializer(variables), fx_array.close(), x_array.close()]

    # Operator to update the parameters and the RNN state after our loop, but
    # during an epoch. (This rebinds the name of the update() helper above,
    # which is no longer needed here.)
    with tf.name_scope("update"):
      update = (nest.flatten(_nested_assign(x, x_final)) +
                nest.flatten(_nested_assign(state, s_final)))

    # Log internal variables.
    for k, net in nets.items():
      

      print("Optimizer '{}' variables".format(k))
      print([op.name for op in snt.get_variables_in_module(net)])
    
    print(fx_final)
    

    return MetaLoss(loss, update, reset, fx_final, x_final, constants)
Beispiel #13
0
    def meta_loss(self,
                  make_loss,
                  len_unroll,
                  net_assignments=None,
                  second_derivatives=False):
        """Builds the meta-loss graph plus auxiliary multi-task (MT) losses.

        The optimizee variables are unrolled for `len_unroll` steps with the
        learned optimizer networks. At every step the optimizee loss is built on
        the variables multiplied element-wise by per-variable `scale`
        placeholders (which default to ones). Additionally, for each of the
        `self.num_mt` auxiliary tasks, a separate while loop feeds rows of input
        placeholders through the same optimizer networks and accumulates a
        0.5 * squared-error loss against rows of label placeholders.

        Args:
          make_loss: Callable which returns the optimizee loss; note that this
              should create its ops in the default graph.
          len_unroll: Number of steps to unroll.
          net_assignments: variable to optimizer mapping. If not None, it should
              be a list of (k, names) tuples, where k is a valid key in the
              kwargs passed at construction time and names is a list of
              variable names.
          second_derivatives: Use second derivatives (default is false).

        Returns:
          A 10-tuple:
            - MetaLoss(loss, update, reset, fx_final, x_final).
            - scale: list of scale placeholders, one per optimizee variable.
            - x: list of optimizee variables.
            - constants: list of problem variables from `_get_variables`.
            - subsets: index subsets of `x`, one per optimizer network.
            - loss_mt: list of scalar MT losses, one per task.
            - update_mt: per-task lists of assign ops writing the final MT
              loop state back into the MT state variables.
            - reset_mt: per-task lists [initializer of the MT state variables,
              close op of the task's loss TensorArray].
            - mt_labels: per-task, per-subset float32 placeholders of shape
              (len_unroll, shapes[j]), where shapes[j] is the leading dimension
              of the concatenated layer-0 state of subset j.
            - mt_inputs: per-task, per-subset placeholders, same shapes.
        """

        # Construct an instance of the problem only to grab the variables. This
        # loss will never be evaluated.

        x, constants = _get_variables(make_loss)

        print("Optimizee variables")
        print([op.name for op in x])
        print("Problem variables")
        print([op.name for op in constants])

        # One scale placeholder per optimizee variable, defaulting to ones; the
        # optimizee loss is always evaluated on x * scale. `k.name[:-2]`
        # presumably strips the ":0" output suffix from the variable name.
        scale = []
        for k in x:
            scale.append(
                tf.placeholder_with_default(tf.ones(shape=k.shape),
                                            shape=k.shape,
                                            name=k.name[:-2] + "_scale"))

        # Create the optimizer networks and find the subsets of variables to assign
        # to each optimizer.
        nets, net_keys, subsets = _make_nets(x, self._config, net_assignments)
        print('nets', nets)
        print('subsets', subsets)
        # Store the networks so we can save them later.
        self._nets = nets

        # Create hidden state for each subset of variables, held in
        # non-trainable variables so it persists between unrolls.
        state = []
        with tf.name_scope("states"):
            for i, (subset, key) in enumerate(zip(subsets, net_keys)):
                net = nets[key]
                with tf.name_scope("state_{}".format(i)):
                    state.append(
                        _nested_variable([
                            net.initial_state_for_inputs(x[j],
                                                         dtype=tf.float32)
                            for j in subset
                        ],
                                         name="state",
                                         trainable=False))

        def update(net, fx, x, state):
            """Parameter and RNN state update."""
            with tf.name_scope("gradients"):
                gradients = tf.gradients(fx, x)

                # Stopping the gradient here corresponds to what was done in the
                # original L2L NIPS submission. However it looks like things like
                # BatchNorm, etc. don't support second-derivatives so we still need
                # this term.
                if not second_derivatives:
                    gradients = [tf.stop_gradient(g) for g in gradients]

            with tf.name_scope("deltas"):
                deltas, state_next = zip(
                    *[net(g, s) for g, s in zip(gradients, state)])
                # Normalize the nested structure (via `_nested_tuple`) before
                # it is returned as a while-loop variable.
                state_next = _nested_tuple(state_next)
                state_next = list(state_next)

            return deltas, state_next

        def time_step(t, fx_array, x, state):
            """While loop body."""
            x_next = list(x)
            state_next = []

            with tf.name_scope("fx"):
                # The optimizee loss sees the scaled variables.
                scaled_x = [x[k] * scale[k] for k in range(len(scale))]
                fx = _make_with_custom_variables(make_loss, scaled_x)
                fx_array = fx_array.write(t, fx)

            with tf.name_scope("dx"):
                for subset, key, s_i in zip(subsets, net_keys, state):
                    x_i = [x[j] for j in subset]
                    deltas, s_i_next = update(nets[key], fx, x_i, s_i)

                    for idx, j in enumerate(subset):
                        delta = deltas[idx]
                        x_next[j] += delta
                    state_next.append(s_i_next)

            with tf.name_scope("t_next"):
                t_next = t + 1

            return t_next, fx_array, x_next, state_next

        # Define the while loop.
        fx_array = tf.TensorArray(tf.float32,
                                  size=len_unroll + 1,
                                  clear_after_read=False)
        _, fx_array, x_final, s_final = tf.while_loop(
            cond=lambda t, *_: t < len_unroll,
            body=time_step,
            loop_vars=(0, fx_array, x, state),
            parallel_iterations=1,
            swap_memory=True,
            name="unroll")

        # Final loss evaluation (slot len_unroll), again on the scaled variables.
        with tf.name_scope("fx"):
            scaled_x_final = [x_final[k] * scale[k] for k in range(len(scale))]
            fx_final = _make_with_custom_variables(make_loss, scaled_x_final)
            fx_array = fx_array.write(len_unroll, fx_final)

        loss = tf.reduce_sum(fx_array.stack(), name="loss")

        ##################################
        ### multi task learning losses ###
        ##################################
        # state (num_subsets, num_x, (num_layers, (h,c)))
        # state_reshape (num_mt, num_subsets, (num_layers, (h, c)))
        # For every MT task and subset, the per-variable h and c of each layer
        # are concatenated along axis 0 and copied into fresh non-trainable
        # variables, giving each task its own copy of the optimizer state.
        # NOTE(review): num_layers is read from state[0][0] even when
        # self.num_mt == 0, so an empty `state` fails here regardless.
        state_reshape = []
        num_layers = len(state[0][0])
        for mti in range(self.num_mt):
            state_reshape_mti = []
            for state_subset in state:
                state_layers = ()
                for li in range(num_layers):
                    h = tf.concat([st_x[li][0] for st_x in state_subset],
                                  axis=0)
                    c = tf.concat([st_x[li][1] for st_x in state_subset],
                                  axis=0)
                    h = tf.Variable(h, name="state_reshape_h", trainable=False)
                    c = tf.Variable(c, name="state_reshape_c", trainable=False)
                    state_layers += ((h, c), )
                state_reshape_mti.append(state_layers)
            state_reshape.append(state_reshape_mti)
        # shapes[j]: leading dimension of subset j's concatenated layer-0 h,
        # used as the per-row size of the MT placeholders.
        if self.num_mt > 0:
            shapes = [
                st_subset[0][0].get_shape().as_list()[0]
                for st_subset in state_reshape[0]
            ]
        else:
            shapes = []
        num_params_total = sum(shapes)
        print("number of parameters = {}".format(num_params_total))

        # placeholder (num_mt, num_subsets, len_unroll, num_params)
        mt_labels = []
        mt_inputs = []
        for i in range(self.num_mt):
            mt_labels.append([
                tf.placeholder(dtype=tf.float32,
                               shape=(len_unroll, shapes[j]),
                               name="mt{}_label_subset{}".format(i, j))
                for j in range(len(subsets))
            ])
            mt_inputs.append([
                tf.placeholder(dtype=tf.float32,
                               shape=(len_unroll, shapes[j]),
                               name="mt{}_input_subset{}".format(i, j))
                for j in range(len(subsets))
            ])

        # loop
        def time_step_mt(mti):
            """Builds the while-loop body for MT task `mti`.

            At step t, row t of each subset's input placeholder is fed through
            that subset's network; the loss is 0.5 * sum((label - delta)^2)
            over all subsets, divided by num_params_total.
            """
            def time_step_func(t, loss_array, states):
                loss_t_sum = 0.0
                state_next = []
                for si, (k, st) in enumerate(zip(net_keys, states)):
                    net = nets[k]
                    g_input = tf.gather(mt_inputs[mti][si], indices=t, axis=0)
                    g_label = tf.gather(mt_labels[mti][si], indices=t, axis=0)
                    delta, state_next_si = net(g_input, st)
                    loss_t_sum += tf.reduce_sum(
                        (g_label - delta) * (g_label - delta)) * 0.5
                    state_next_si = _nested_tuple(state_next_si)
                    state_next.append(state_next_si)
                loss_t = loss_t_sum / num_params_total
                loss_array = loss_array.write(t, loss_t)
                t_next = t + 1
                return t_next, loss_array, state_next

            return time_step_func

        loss_arrays = [
            tf.TensorArray(tf.float32, size=len_unroll, clear_after_read=False)
            for _ in range(self.num_mt)
        ]
        state_reshape_final = []
        # One independent while loop per MT task, each starting from that
        # task's copy of the state.
        for mti in range(self.num_mt):
            loss_array = loss_arrays[mti]
            _, loss_array, state_reshape_final_mti = tf.while_loop(
                cond=lambda t, *_: t < len_unroll,
                body=time_step_mt(mti),
                loop_vars=(0, loss_array, state_reshape[mti]),
                parallel_iterations=1,
                swap_memory=True,
                name="unroll_mt")
            loss_arrays[mti] = loss_array
            state_reshape_final.append(state_reshape_final_mti)

        # loss
        loss_mt = [
            tf.reduce_sum(loss_array.stack(), name="loss_mt{}".format(i))
            for i, loss_array in enumerate(loss_arrays)
        ]

        # Reset the state; should be called at the beginning of an epoch.
        with tf.name_scope("reset"):
            variables = (nest.flatten(state) + x + constants)
            # Empty array as part of the reset process.
            reset = [tf.variables_initializer(variables), fx_array.close()]

            # mt
            variables_mt = [
                nest.flatten(state_reshape[mti]) for mti in range(self.num_mt)
            ]
            reset_mt = [[
                tf.variables_initializer(variables_mt[mti]),
                loss_arrays[mti].close()
            ] for mti in range(self.num_mt)]

        # Operator to update the parameters and the RNN state after our loop, but
        # during an epoch.
        with tf.name_scope("update"):
            update = (nest.flatten(_nested_assign(x, x_final)) +
                      nest.flatten(_nested_assign(state, s_final)))
            update_mt = [(nest.flatten(
                _nested_assign(state_reshape[mti], state_reshape_final[mti])))
                         for mti in range(self.num_mt)]

        # Log internal variables.
        for k, net in nets.items():
            print("Optimizer '{}' variables".format(k))
            print([op for op in snt.get_variables_in_module(net)])

        return MetaLoss(loss, update, reset, fx_final, x_final), scale, x, constants, subsets,\
               loss_mt, update_mt, reset_mt, mt_labels, mt_inputs
Beispiel #14
0
 def b(self):
   """Return this module's single variable whose raw name is "b"."""
   found = []
   for var in snt.get_variables_in_module(self):
     if self._raw_name(var.name) == "b":
       found.append(var)
   assert len(found) == 1
   return found[0]
Beispiel #15
0
 def w(self):
   """Return this module's single variable whose raw name is "w"."""
   found = []
   for var in snt.get_variables_in_module(self):
     if self._raw_name(var.name) == "w":
       found.append(var)
   assert len(found) == 1
   return found[0]
Beispiel #16
0
  def meta_loss(self,
                make_loss,
                len_unroll,
                net_assignments=None,
                second_derivatives=False):
    """Returns an operator computing the meta-loss.

    The optimizee variables are unrolled for `len_unroll` steps, each step
    adding the deltas produced by the optimizer networks. The networks' RNN
    state is held in variables created once, so `reset` re-initializes exactly
    the state the loop reads and `update` writes the post-unroll state back
    into it (letting the state carry over between unrolls of one epoch).

    Args:
      make_loss: Callable which returns the optimizee loss; note that this
          should create its ops in the default graph.
      len_unroll: Number of steps to unroll.
      net_assignments: variable to optimizer mapping. If not None, it should be
          a list of (k, names) tuples, where k is a valid key in the kwargs
          passed at construction time and names is a list of variable names.
      second_derivatives: Use second derivatives (default is false).

    Returns:
      namedtuple containing (loss, update, reset, fx, x)
    """

    # Construct an instance of the problem only to grab the variables. This
    # loss will never be evaluated.
    x, constants = _get_variables(make_loss)

    print("Optimizee variables")
    print([op.name for op in x])
    print("Problem variables")
    print([op.name for op in constants])

    # Create the optimizer networks and find the subsets of variables to assign
    # to each optimizer.
    nets, net_keys, subsets = _make_nets(x, self._config, net_assignments)

    # Store the networks so we can save them later.
    self._nets = nets

    # Create hidden state for each subset of variables. The state variables are
    # created here exactly once (previously `reset` and `update` each created
    # their own fresh copies, so neither touched the state the loop used).
    # They are re-packed into exactly the nested structure returned by
    # `initial_state_for_inputs`, i.e. the structure the loop variables had
    # before, so the while-loop body's outputs still match it.
    state = []
    with tf.name_scope("states"):
      for i, (subset, key) in enumerate(zip(subsets, net_keys)):
        net = nets[key]
        with tf.name_scope("state_{}".format(i)):
          init_state = [net.initial_state_for_inputs(x[j], dtype=tf.float32)
                        for j in subset]
          state_vars = nest.flatten(_nested_variable(init_state))
          state.append(nest.pack_sequence_as(init_state, state_vars))

    def update(net, fx, x, state):
      """Parameter and RNN state update."""
      with tf.name_scope("gradients"):
        gradients = tf.gradients(fx, x)

        # Stopping the gradient here corresponds to what was done in the
        # original L2L NIPS submission. However it looks like things like
        # BatchNorm, etc. don't support second-derivatives so we still need
        # this term.
        if not second_derivatives:
          gradients = [tf.stop_gradient(g) for g in gradients]

      with tf.name_scope("deltas"):
        deltas, state_next = zip(*[net(g, s) for g, s in zip(gradients, state)])
        state_next = list(state_next)

      return deltas, state_next

    def time_step(t, fx_array, x, state):
      """While loop body."""
      x_next = list(x)
      state_next = []

      with tf.name_scope("fx"):
        fx = _make_with_custom_variables(make_loss, x)
        fx_array = fx_array.write(t, fx)

      with tf.name_scope("dx"):
        for subset, key, s_i in zip(subsets, net_keys, state):
          x_i = [x[j] for j in subset]
          deltas, s_i_next = update(nets[key], fx, x_i, s_i)

          for idx, j in enumerate(subset):
            x_next[j] += deltas[idx]
          state_next.append(s_i_next)

      with tf.name_scope("t_next"):
        t_next = t + 1

      return t_next, fx_array, x_next, state_next

    # Define the while loop.
    fx_array = tf.TensorArray(tf.float32, size=len_unroll + 1,
                              clear_after_read=False)
    _, fx_array, x_final, s_final = tf.while_loop(
        cond=lambda t, *_: t < len_unroll,
        body=time_step,
        loop_vars=(0, fx_array, x, state),
        parallel_iterations=1,
        swap_memory=True,
        name="unroll")

    with tf.name_scope("fx"):
      fx_final = _make_with_custom_variables(make_loss, x_final)
      fx_array = fx_array.write(len_unroll, fx_final)

    loss = tf.reduce_sum(fx_array.stack(), name="loss")

    # Reset the state; should be called at the beginning of an epoch.
    with tf.name_scope("reset"):
      variables = (nest.flatten(state) + x + constants)
      # Empty array as part of the reset process.
      reset = [tf.variables_initializer(variables), fx_array.close()]

    # Operator to update the parameters and the RNN state after our loop, but
    # during an epoch. Writes into the same state variables the loop reads.
    with tf.name_scope("update"):
      update = (nest.flatten(_nested_assign(x, x_final)) +
                nest.flatten(_nested_assign(state, s_final)))

    # Log internal variables.
    for k, net in nets.items():
      print("Optimizer '{}' variables".format(k))
      print([op.name for op in snt.get_variables_in_module(net)])

    return MetaLoss(loss, update, reset, fx_final, x_final)
Beispiel #17
0
    def meta_loss(self,
                  make_loss,
                  len_unroll,
                  net_assignments=None,
                  second_derivatives=False):
        """Returns an operator computing the meta-loss over several optimizees.

        The variables of every loss callable in `make_loss.values()` are
        concatenated into one list `x`; the b-th callable owns
        `x[var_num[b]:var_num[b + 1]]`. Each unroll step evaluates every loss on
        its own variables, records their sum, and applies the deltas proposed
        by the optimizer networks.

        Args:
          make_loss: Mapping whose values are callables returning an optimizee
              loss; each should create its ops in the default graph. The first
              value (in iteration order) is the one reported as `fxopt`/`farray`.
          len_unroll: Number of loss evaluations recorded; the loop performs
              `len_unroll - 1` update steps.
          net_assignments: variable to optimizer mapping. If not None, it should
              be a list of (k, names) tuples, where k is a valid key in the
              kwargs passed at construction time and names is a list of
              variable names.
          second_derivatives: Use second derivatives (default is false).

        Returns:
          MetaLoss(loss, update, reset, fxopt, farray, lr_opt, x_final), where
          `loss` is the sum over steps of the summed losses, `fxopt` is the
          final loss of the first optimizee, `farray` the stacked per-step
          losses of the first optimizee and `lr_opt` the stacked per-step mean
          delta/gradient ratios.
        """

        # Construct an instance of the problem only to grab the variables. This
        # loss will never be evaluated.
        x = []
        constants = []
        for a in make_loss.values():
            item1, item2 = util._get_variables(a)
            x.append(item1)
            constants.append(item2)
        # num_var[b + 1]: number of variables of the b-th optimizee.
        # var_num[b]:var_num[b + 1]: their index range in the flattened `x`.
        num_var = [0] + [len(a) for a in x]
        var_num = np.cumsum(num_var)
        num_optimizees = len(num_var) - 1
        x = nest.flatten(x)
        constants = nest.flatten(constants)
        print("Optimizee variables")
        print([op.name for op in x])
        print("Problem variables")
        print([op.name for op in constants])

        def _own_vars(xs, b):
            """Returns the entries of `xs` belonging to the b-th optimizee."""
            # Must use the cumulative offset: slicing from num_var[b] (a count)
            # selects the wrong variables from the third optimizee onwards.
            return xs[var_num[b]:var_num[b + 1]]

        # Create the optimizer networks and find the subsets of variables to assign
        # to each optimizer.
        nets, net_keys, subsets = util._make_nets(x, self._config,
                                                  net_assignments)

        # Store the networks so we can save them later.
        self._nets = nets

        # Create hidden state for each subset of variables.
        state = []
        with tf.name_scope("states"):
            for i, (subset, key) in enumerate(zip(subsets, net_keys)):
                net = nets[key]
                with tf.name_scope("state_{}".format(i)):
                    state.append(
                        util._nested_variable([
                            net.initial_state_for_inputs(x[j],
                                                         dtype=tf.float32)
                            for j in subset
                        ],
                                              name="state",
                                              trainable=False))

        def update(net, fx, x, state, subset):
            """Parameter and RNN state update."""
            with tf.name_scope("gradients"):
                if len(subset) == sum(num_var):
                    # The subset covers every variable: differentiate each loss
                    # only w.r.t. the variables that optimizee owns.
                    gradients = nest.flatten([
                        tf.gradients(fx[b], _own_vars(x, b))
                        for b in range(num_optimizees)
                    ])
                else:
                    # Map each variable index to the optimizee that owns it.
                    bin_num = np.digitize(subset, var_num) - 1
                    if np.std(bin_num) == 0 or len(bin_num) == 1:
                        gradients = nest.flatten(
                            [tf.gradients(fx[bin_num[0]], x)])
                    else:
                        gradients = nest.flatten([
                            tf.gradients(fx[a], x[b])
                            for a, b in zip(bin_num, range(len(x)))
                        ])
                if not second_derivatives:
                    gradients = [tf.stop_gradient(g) for g in gradients]

            with tf.name_scope("deltas"):
                deltas, state_next = zip(
                    *[net(g, s) for g, s in zip(gradients, state)])
                state_next = list(state_next)

            # Mean element-wise delta/gradient ratio, averaged over variables.
            # NOTE(review): zero gradient entries make this ratio inf/nan.
            ratio = sum([
                tf.reduce_mean(tf.div(d, g))
                for d, g in zip(deltas, gradients)
            ]) / len(gradients)

            return deltas, state_next, ratio

        def time_step(t, fx_array, fx_array_opt, lr_optimizee, x, state):
            """While loop body."""
            x_next = list(x)
            state_next = []
            ratio = []

            with tf.name_scope("fx"):
                fx = [
                    util._make_with_custom_variables(a, _own_vars(x, b))
                    for b, a in enumerate(make_loss.values())
                ]

            with tf.name_scope("fx_sum"):
                fxsum = sum(fx)
                fx_array = fx_array.write(t, fxsum)

            with tf.name_scope("fx_opt"):
                fxopt = fx[0]
                fx_array_opt = fx_array_opt.write(t, fxopt)

            with tf.name_scope("dx"):
                for subset, key, s_i in zip(subsets, net_keys, state):
                    x_i = [x[j] for j in subset]
                    deltas, s_i_next, ratio_i = update(nets[key], fx, x_i, s_i,
                                                       subset)
                    for idx, j in enumerate(subset):
                        x_next[j] += deltas[idx]
                    state_next.append(s_i_next)
                    ratio.append(ratio_i)

            with tf.name_scope("lr_opt"):
                lr_optimizee = lr_optimizee.write(t, sum(ratio) / len(ratio))

            with tf.name_scope("t_next"):
                t_next = t + 1

            return t_next, fx_array, fx_array_opt, lr_optimizee, x_next, state_next

        # Define the while loop. It runs len_unroll - 1 update steps; the last
        # loss slot (index len_unroll - 1) is filled after the loop.
        fx_array = tf.TensorArray(tf.float32,
                                  size=len_unroll,
                                  clear_after_read=False)
        fx_array_opt = tf.TensorArray(tf.float32,
                                      size=len_unroll,
                                      clear_after_read=False)
        lr_optimizee = tf.TensorArray(tf.float32,
                                      size=len_unroll - 1,
                                      clear_after_read=False)
        _, fx_array, fx_array_opt, lr_optimizee, x_final, s_final = tf.while_loop(
            cond=lambda t, *_: t < len_unroll - 1,
            body=time_step,
            loop_vars=(0, fx_array, fx_array_opt, lr_optimizee, x, state),
            parallel_iterations=1,
            swap_memory=True,
            name="unroll")

        with tf.name_scope("fx"):
            fx_final = [
                util._make_with_custom_variables(a, _own_vars(x_final, b))
                for b, a in enumerate(make_loss.values())
            ]

        with tf.name_scope("fx_sum"):
            fxsum = sum(fx_final)
            fx_array = fx_array.write(len_unroll - 1, fxsum)

        with tf.name_scope("fx_opt"):
            fxopt = fx_final[0]
            fx_array_opt = fx_array_opt.write(len_unroll - 1, fxopt)
            farray = fx_array_opt.stack()

        with tf.name_scope("lr_opt"):
            lr_opt = lr_optimizee.stack()

        loss = tf.reduce_sum(fx_array.stack(), name="loss")

        # Reset the state; should be called at the beginning of an epoch.
        with tf.name_scope("reset"):
            variables = (nest.flatten(state) + x + constants)
            # Empty array as part of the reset process.
            reset = [
                tf.variables_initializer(variables),
                fx_array.close(),
                fx_array_opt.close(),
                lr_optimizee.close()
            ]

        # Operator to update the parameters and the RNN state after our loop, but
        # during an epoch.
        with tf.name_scope("update"):
            update = (nest.flatten(util._nested_assign(x, x_final)) +
                      nest.flatten(util._nested_assign(state, s_final)))

        # Log internal variables.
        for k, net in nets.items():
            print("Optimizer '{}' variables".format(k))
            print([op.name for op in snt.get_variables_in_module(net)])

        return MetaLoss(loss, update, reset, fxopt, farray, lr_opt, x_final)
Beispiel #18
0
 def b(self):
     """Return the bias variable "b"; the module must own exactly two variables."""
     module_vars = snt.get_variables_in_module(self)
     num_vars = len(module_vars)
     assert num_vars == 2, "Found not 2 but %d" % num_vars
     biases = [v for v in module_vars if self._raw_name(v.name) == "b"]
     assert len(biases) == 1
     return biases[0]
Beispiel #19
0
 def w(self):
     """Return the module's single variable whose raw name is "w"."""
     weights = []
     for v in snt.get_variables_in_module(self):
         if self._raw_name(v.name) == "w":
             weights.append(v)
     assert len(weights) == 1
     return weights[0]
Beispiel #20
0
def main(unused_argv):
    """Restore the latest checkpoint and print the accuracy metric.

    Reads `eeg-xval.tfr` (when `FLAGS.evaluate`) or `eeg-test.tfr` from
    `FLAGS.data_dir`, rebuilds the classifier (preceded by an adaptor when
    `FLAGS.adp`), restores its variables from the checkpoint in
    `FLAGS.output_dir`, runs the metric update op `FLAGS.test_size` times and
    prints the resulting metric. Returns early if no checkpoint is found.

    Raises:
      ValueError: if the data or output directory is missing, or
        `FLAGS.test_size` is not positive.
    """
    if FLAGS.data_dir == '' or not os.path.exists(FLAGS.data_dir):
        raise ValueError('invalid data directory')

    if FLAGS.evaluate:
        print("evaluate the model")
        data_path = os.path.join(FLAGS.data_dir, 'eeg-xval.tfr')
    else:
        print("model inference")
        data_path = os.path.join(FLAGS.data_dir, 'eeg-test.tfr')

    if FLAGS.output_dir == '' or not os.path.exists(FLAGS.output_dir):
        raise ValueError('invalid output directory {}'.format(FLAGS.output_dir))

    # Validated up front (previously an `assert` after restoring) so a bad value
    # fails before the graph is built and is still rejected under `python -O`.
    if FLAGS.test_size <= 0:
        raise ValueError("invalid test samples")

    # os.path.join(dir, '') appends a trailing path separator.
    checkpoint_dir = os.path.join(FLAGS.output_dir, '')

    print('reconstructing models and inputs.')
    input_ = Input(1, FLAGS.num_points)

    waves, labels = input_(data_path)

    if FLAGS.adp:
        adaptor = Adaptor()
        classifier = ReducedClassifier()

        logits = adaptor(waves)
        logits = classifier(logits)
    else:
        classifier = Classifier(FLAGS.num_points, FLAGS.sampling_rate)
        logits = classifier(waves, expand_dims=True)

    # Predicted class index per example (argmax over the last axis).
    predictions = tf.argmax(logits, axis=-1)

    metrics = Metrics("accuracy")
    with tf.control_dependencies(
            [tf.assert_equal(tf.rank(labels), tf.rank(predictions))]):
        metric_op, metric_update_op = metrics(labels, predictions)

    if FLAGS.adp:
        variables = (snt.get_variables_in_module(adaptor) +
                     snt.get_variables_in_module(classifier))
    else:
        variables = snt.get_variables_in_module(classifier)
    saver = tf.train.Saver(variables)

    with tf.Session() as sess:
        # Model weights come from the checkpoint; only local variables
        # (presumably the metric's accumulators) are initialized here.
        sess.run(tf.local_variables_initializer())

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            print('No checkpoint file found')
            return
        # Restores from checkpoint
        saver.restore(sess, ckpt.model_checkpoint_path)

        for _ in range(FLAGS.test_size):
            sess.run(metric_update_op)

        metric = sess.run(metric_op)
        print("metric -> {}".format(metric))
Beispiel #21
0
def train(train_log_dir, checkpoint, eval_every_n_steps=10, num_steps=3000):
  """Train the MoreLocalWeightUpdate evaluation graph on TinyMnist.

  Builds the evaluation graph. If `checkpoint` is given, it restores every
  graph variable whose name appears in that checkpoint and resets the global
  step to 0. It then runs `train_op` until the global step reaches
  `num_steps`, writing summaries every `eval_every_n_steps` steps.

  Args:
    train_log_dir: directory handed to `summary_utils.LoggingFileWriter`.
    checkpoint: checkpoint path to warm-start from, or a falsy value to start
      from freshly initialized variables.
    eval_every_n_steps: period, in global steps, of summary writing.
    num_steps: global step at which training stops.
  """
  dataset_fn = datasets.mnist.TinyMnist
  w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
  theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess

  meta_objectives = [
      meta_objective.linear_regression.LinearRegressionMetaObjective,
      meta_objective.sklearn.LogisticRegression,
  ]

  _checkpoint_vars, train_one_step_op, (
      base_model, dataset) = evaluation.construct_evaluation_graph(
          theta_process_fn=theta_process_fn,
          w_learner_fn=w_learner_fn,
          dataset_fn=dataset_fn,
          meta_objectives=meta_objectives)
  batch = dataset()
  # The outputs are unused; the call is kept for its graph-building effect.
  _pre_logit, _outputs = base_model(batch)

  global_step = tf.train.get_or_create_global_step()

  tf.logging.info("all vars")
  # tf.global_variables replaces the deprecated tf.all_variables alias.
  for v in tf.global_variables():
    tf.logging.info("   %s", v)
  accumulate_global_step = global_step.assign_add(1)
  reset_global_step = global_step.assign(0)

  train_op = tf.group(
      train_one_step_op, accumulate_global_step, name="train_op")

  summary_op = tf.summary.merge_all()

  file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])

  saver = None
  if checkpoint:
    # Restore every graph variable that the checkpoint contains.
    str_var_list = checkpoint_utils.list_variables(checkpoint)
    name_to_v_map = {v.op.name: v for v in tf.global_variables()}
    var_list = [
        name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
    ]
    saver = tf.train.Saver(var_list)
    # Every variable under the "LocalWeightUpdateProcess" scope must be
    # restorable from the checkpoint.
    missed_variables = [
        v.op.name for v in set(
            snt.get_variables_in_scope("LocalWeightUpdateProcess",
                                       tf.GraphKeys.GLOBAL_VARIABLES)) -
        set(var_list)
    ]
    assert len(missed_variables) == 0, "Missed a theta variable."

  hooks = []

  with tf.train.SingularMonitoredSession(master="", hooks=hooks) as sess:

    # global step should be restored from the evals job checkpoint or zero for fresh.
    step = sess.run(global_step)

    if step == 0 and checkpoint:
      tf.logging.info("force restore")
      saver.restore(sess, checkpoint)
      tf.logging.info("force restore done")
      sess.run(reset_global_step)
      step = sess.run(global_step)

    while step < num_steps:
      if step % eval_every_n_steps == 0:
        s, _, step = sess.run([summary_op, train_op, global_step])
        file_writer.add_summary(s, step)
      else:
        _, step = sess.run([train_op, global_step])
Beispiel #22
0
def train(train_log_dir,
          checkpoint_dir,
          eval_every_n_steps=10,
          num_steps=3000):
    """Runs the MoreLocalWeightUpdate evaluation graph on TinyMnist.

    Builds the evaluation graph, optionally restores every graph variable that
    is present in the checkpoint (only when the global step is 0), and then
    runs `train_one_step_op` until the global step reaches `num_steps`. Every
    `eval_every_n_steps` steps the merged summaries are written as well.

    Args:
      train_log_dir: directory handed to `summary_utils.LoggingFileWriter`.
      checkpoint_dir: checkpoint directory (its latest checkpoint is used) or
        a checkpoint path prefix; falsy to start from freshly initialised
        variables.
      eval_every_n_steps: period, in global steps, of summary writing.
      num_steps: training stops once the global step reaches this value.

    Raises:
      ValueError: if a variable in the "LocalWeightUpdateProcess" scope is not
        present in the checkpoint.
    """
    dataset_fn = datasets.mnist.TinyMnist
    w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
    theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess

    meta_objectives = [
        meta_objective.linear_regression.LinearRegressionMetaObjective,
        meta_objective.sklearn.LogisticRegression,
    ]

    _, train_one_step_op, (base_model, dataset) = (
        evaluation.construct_evaluation_graph(
            theta_process_fn=theta_process_fn,
            w_learner_fn=w_learner_fn,
            dataset_fn=dataset_fn,
            meta_objectives=meta_objectives))
    # Connect the base model to a batch so its ops and variables exist in the
    # graph; the returned tensors themselves are not needed here.
    batch = dataset()
    base_model(batch)

    global_step = tf.train.get_or_create_global_step()

    tf.logging.info("all vars")
    for v in tf.global_variables():
        tf.logging.info("   %s" % str(v))
    accumulate_global_step = global_step.assign_add(1)
    reset_global_step = global_step.assign(0)

    train_op = tf.group(train_one_step_op,
                        accumulate_global_step,
                        name="train_op")

    summary_op = tf.summary.merge_all()

    file_writer = summary_utils.LoggingFileWriter(train_log_dir,
                                                  regexes=[".*"])

    saver = None
    restore_path = None
    if checkpoint_dir:
        # Saver.restore needs a checkpoint prefix, not a directory: use the
        # newest checkpoint of a directory, or the given path if it already is
        # a prefix (latest_checkpoint returns None in that case).
        restore_path = (tf.train.latest_checkpoint(checkpoint_dir) or
                        checkpoint_dir)
        str_var_list = checkpoint_utils.list_variables(restore_path)
        name_to_v_map = {v.op.name: v for v in tf.global_variables()}
        var_list = [
            name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
        ]
        saver = tf.train.Saver(var_list)
        # Every theta (LocalWeightUpdateProcess) variable must come from the
        # checkpoint; raise explicitly so the check survives `python -O`.
        restored = set(var_list)
        missed_variables = sorted(
            v.op.name for v in snt.get_variables_in_scope(
                "LocalWeightUpdateProcess", tf.GraphKeys.GLOBAL_VARIABLES)
            if v not in restored)
        if missed_variables:
            raise ValueError(
                "Missed a theta variable: %s" % ", ".join(missed_variables))

    with tf.train.SingularMonitoredSession(master="", hooks=[]) as sess:

        # global step should be restored from the evals job checkpoint or zero for fresh.
        step = sess.run(global_step)

        if step == 0 and saver is not None:
            tf.logging.info("force restore")
            saver.restore(sess, restore_path)
            tf.logging.info("force restore done")
            sess.run(reset_global_step)
            step = sess.run(global_step)

        while step < num_steps:
            if step % eval_every_n_steps == 0:
                s, _, step = sess.run([summary_op, train_op, global_step])
                file_writer.add_summary(s, step)
            else:
                _, step = sess.run([train_op, global_step])
Beispiel #23
0
  def meta_loss(self,
                make_loss,
                len_unroll,
                net_assignments=None,
                second_derivatives=False):
    """Returns an operator computing the meta-loss.

    Unrolls `len_unroll` steps of the learned optimizer inside a
    `tf.while_loop`. At each step the loss is evaluated on the scaled optimizee
    variables, Adam-style first/second moment estimates of the gradients are
    updated, and each optimizer network is called with the bias-corrected,
    normalised moment `m_hat / (sqrt(v_hat) + 1e-8)` and normalised gradient
    `g / (sqrt(v_hat) + 1e-8)` to produce parameter deltas. The meta-loss is
    the sum of the `len_unroll + 1` recorded loss values.

    Args:
      make_loss: Callable which returns the optimizee loss; note that this
          should create its ops in the default graph.
      len_unroll: Number of steps to unroll.
      net_assignments: variable to optimizer mapping. If not None, it should be
          a list of (k, names) tuples, where k is a valid key in the kwargs
          passed at construction time and names is a list of variable names.
      second_derivatives: Use second derivatives (default is false).

    Returns:
      A 4-tuple `(meta_loss, scale, x, step)`:
        meta_loss: `MetaLoss(loss, update, reset, fx_final, x_final)`.
        scale: one `tf.placeholder_with_default` per optimizee variable
          (default: ones of the variable's shape). Each variable is multiplied
          elementwise by its scale before every call to `make_loss`.
        x: the list of optimizee variables.
        step: scalar int32 placeholder without a default. It is added to the
          loop counter in the Adam bias-correction exponent, so it must be fed
          when running ops that depend on the unroll (e.g. `loss`).
    """

    # Construct an instance of the problem only to grab the variables. This
    # loss will never be evaluated.
    
    x, constants = _get_variables(make_loss)

    print("Optimizee variables")
    print([op.name for op in x])
    print("Problem variables")
    print([op.name for op in constants])

    # create scale placeholder here
    # One scale per optimizee variable; it defaults to ones, so the loss sees
    # the unscaled variables unless a scale is fed. `k.name[:-2]` drops the
    # last two characters of the variable name (presumably the ":0" suffix).
    scale = []
    for k in x:
      scale.append(tf.placeholder_with_default(tf.ones(shape=k.shape), shape=k.shape, name=k.name[:-2] + "_scale"))
    # Offset added to the in-loop counter `t` in the Adam bias correction.
    step = tf.placeholder(shape=(), name="step", dtype=tf.int32)

    # Create the optimizer networks and find the subsets of variables to assign
    # to each optimizer.
    nets, net_keys, subsets = _make_nets(x, self._config, net_assignments)
    print('nets', nets)
    print('subsets', subsets)
    # Store the networks so we can save them later.
    self._nets = nets

    # Create hidden state for each subset of variables.
    state = []
    with tf.name_scope("states"):
      for i, (subset, key) in enumerate(zip(subsets, net_keys)):
        net = nets[key]
        with tf.name_scope("state_{}".format(i)):
          state.append(_nested_variable(
              [net.initial_state_for_inputs(x[j], dtype=tf.float32)
               for j in subset],
              name="state", trainable=False))

    # m and v in adam
    # Non-trainable, zero-initialised first (mt) and second (vt) moment
    # accumulators: one per optimizee variable, grouped by subset.
    state_mt = []
    state_vt = []
    for i, (subset, key) in enumerate(zip(subsets, net_keys)):
      mt = [tf.Variable(tf.zeros(shape=x[j].shape), name=x[j].name[:-2] + "_mt", dtype=tf.float32, trainable=False) for j in subset]
      vt = [tf.Variable(tf.zeros(shape=x[j].shape), name=x[j].name[:-2] + "_vt", dtype=tf.float32, trainable=False) for j in subset]
      state_mt.append(mt)
      state_vt.append(vt)

    def update(net, fx, x, state, mt, vt, t):
      """Parameter and RNN state update.

      Returns:
        `(deltas, state_next, mt_next, vt_next)` for the variables in `x`.
      """
      with tf.name_scope("gradients"):
        gradients = tf.gradients(fx, x)

        # Stopping the gradient here corresponds to what was done in the
        # original L2L NIPS submission. However it looks like things like
        # BatchNorm, etc. don't support second-derivatives so we still need
        # this term.
        if not second_derivatives:
          gradients = [tf.stop_gradient(g) for g in gradients]
        # update mt and vt
        # Adam moment updates, then bias correction with exponent `step + t`.
        # NOTE(review): if `step` is fed as 0, the first iteration (t == 0)
        # divides by 1 - beta**0 == 0; confirm callers feed step >= 1.
        mt_next = [self.beta1*m + (1.0-self.beta1)*g for m, g in zip(mt, gradients)]
        mt_hat = [m/(1-tf.pow(self.beta1, tf.cast(step+t, dtype=tf.float32))) for m in mt_next]
        vt_next = [self.beta2 * v + (1.0 - self.beta2) * g*g for v, g in zip(vt, gradients)]
        vt_hat = [v/(1 - tf.pow(self.beta2, tf.cast(step+t, dtype=tf.float32))) for v in vt_next]
        # Inputs to the optimizer network: moment and raw gradient, both
        # normalised by the bias-corrected second moment.
        mt_tilde = [m / (tf.sqrt(v)+1e-8) for m, v in zip(mt_hat, vt_hat)]
        gt_tilde = [g / (tf.sqrt(v)+1e-8) for g, v in zip(gradients, vt_hat)]

      with tf.name_scope("deltas"):
        deltas, state_next = zip(*[net(m, g, s) for m, g, s in zip(mt_tilde, gt_tilde, state)])
        state_next = _nested_tuple(state_next)
        state_next = list(state_next)

      return deltas, state_next, mt_next, vt_next

    def time_step(t, fx_array, x, state, state_mt, state_vt):
      """While loop body."""
      x_next = list(x)
      state_next = []
      state_mt_next = []
      state_vt_next = []

      with tf.name_scope("fx"):
        # The loss is always evaluated on the scaled variables.
        scaled_x = [x[k] * scale[k] for k in range(len(scale))]
        fx = _make_with_custom_variables(make_loss, scaled_x)
        fx_array = fx_array.write(t, fx)

      with tf.name_scope("dx"):
        for subset, key, s_i, mt, vt in zip(subsets, net_keys, state, state_mt, state_vt):
          x_i = [x[j] for j in subset]
          deltas, s_i_next, mt_i_next, vt_i_next = update(nets[key], fx, x_i, s_i, mt, vt, t)
          for idx, j in enumerate(subset):
            delta = deltas[idx]
            x_next[j] += delta
          state_next.append(s_i_next)
          state_mt_next.append(mt_i_next)
          state_vt_next.append(vt_i_next)

      with tf.name_scope("t_next"):
        t_next = t + 1

      return t_next, fx_array, x_next, state_next, state_mt_next, state_vt_next

    # Define the while loop.
    # fx_array holds len_unroll in-loop losses plus the final one written below.
    fx_array = tf.TensorArray(tf.float32, size=len_unroll+1,
                              clear_after_read=False)
    _, fx_array, x_final, s_final, mt_final, vt_final = tf.while_loop(
        cond=lambda t, *_: t < len_unroll,
        body=time_step,
        loop_vars=(0, fx_array, x, state, state_mt, state_vt),
        parallel_iterations=1,
        swap_memory=True,
        name="unroll")

    with tf.name_scope("fx"):
      scaled_x_final = [x_final[k] * scale[k] for k in range(len(scale))]
      fx_final = _make_with_custom_variables(make_loss, scaled_x_final)
      fx_array = fx_array.write(len_unroll, fx_final)

    loss = tf.reduce_sum(fx_array.stack(), name="loss")

    # Reset the state; should be called at the beginning of an epoch.
    # Re-runs the initializers of the RNN state, optimizee and problem
    # variables, zeroes the Adam moments and closes the TensorArray.
    with tf.name_scope("reset"):
      variables = (nest.flatten(state) +
                   x + constants)
      reset_mt = [tf.assign(m, tf.zeros(shape=m.shape)) for mt in state_mt for m in mt]
      reset_vt = [tf.assign(v, tf.zeros(shape=v.shape)) for vt in state_vt for v in vt]

      # Empty array as part of the reset process.
      reset = [tf.variables_initializer(variables), fx_array.close()] + reset_mt + reset_vt

    # Operator to update the parameters and the RNN state after our loop, but
    # during an epoch.
    # Writes x, the RNN state and both Adam moments back to their variables.
    with tf.name_scope("update"):
      update = (nest.flatten(_nested_assign(x, x_final)) +
                nest.flatten(_nested_assign(state, s_final)) +
                nest.flatten(_nested_assign(state_mt, mt_final)) +
                nest.flatten(_nested_assign(state_vt, vt_final)))

    # Log internal variables.
    for k, net in nets.items():
      print("Optimizer '{}' variables".format(k))
      print([op for op in snt.get_variables_in_module(net)])

    return MetaLoss(loss, update, reset, fx_final, x_final), scale, x, step
Beispiel #24
0
    def meta_loss(self,
                  make_loss,
                  len_unroll,
                  net_assignments=None,
                  load_states=False,
                  second_derivatives=False):
        """Returns an operator computing the meta-loss.

        Unrolls `len_unroll` steps of the learned optimizer inside a
        `tf.while_loop`. Each optimizer network maps gradients to parameter
        deltas. The meta-loss is the sum of the `len_unroll + 1` recorded
        loss values.

        Args:
          make_loss: Callable which returns the optimizee loss; note that this
              should create its ops in the default graph.
          len_unroll: Number of steps to unroll.
          net_assignments: variable to optimizer mapping. If not None, it
              should be a list of (k, names) tuples, where k is a valid key in
              the kwargs passed at construction time and names is a list of
              variable names.
          load_states: If True, values from `self._states[0]` are assigned to
              the RNN state variables of the first variable subset before the
              unroll.
          second_derivatives: Use second derivatives (default is false).

        Returns:
          `MetaLoss(loss, update, reset, fx_final, x_final, s_final)`, where
          `s_final` is the RNN state after the unroll. As a side effect the
          state variables are stored in `self.init_state` and the networks in
          `self._nets`.
        """

        # Construct an instance of the problem only to grab the variables. This
        # loss will never be evaluated.
        x, constants = _get_variables(make_loss)

        print("Optimizee variables")
        print([op.name for op in x])

        # Create the optimizer networks and find the subsets of variables to assign
        # to each optimizer.
        nets, net_keys, subsets = _make_nets(x, self._config, net_assignments)

        # Store the networks so we can save them later.
        self._nets = nets

        # Create hidden state for each subset of variables.
        state = []
        with tf.name_scope("states"):
            for i, (subset, key) in enumerate(zip(subsets, net_keys)):
                net = nets[key]
                with tf.name_scope("state_{}".format(i)):
                    state.append(
                        _nested_variable([
                            net.initial_state_for_inputs(x[j],
                                                         dtype=tf.float32)
                            for j in subset
                        ],
                                         name="state",
                                         trainable=False))
        self.init_state = state

        # Optionally preload the RNN state of the first subset (state[0]) from
        # self._states[0]. Both nests are flattened by chaining their elements
        # and concatenating them with `+` via reduce, then paired positionally.
        # NOTE(review): assumes self._states[0] mirrors the nesting of state[0]
        # and holds values accepted by tf.convert_to_tensor; only subset 0 is
        # loaded — confirm for multi-net configurations.
        assign_ops = []
        if load_states:
            state_vars = reduce(lambda x, y: x + y, chain(*state[0]))
            state_arrs = reduce(lambda x, y: x + y, chain(*self._states[0]))

            for vv, aa in zip(state_vars, state_arrs):
                assign_ops += [tf.assign(vv, tf.convert_to_tensor(aa))]

        def update(net, fx, x, state):
            """Parameter and RNN state update.

            Also records a "learning_rate" scalar summary, measured as
            |delta| / |gradient| at one fixed position of the flattened
            vectors.

            Returns:
              `(deltas, state_next)` for the variables in `x`.
            """
            with tf.name_scope("gradients"):
                gradients = tf.gradients(fx, x)

                # gradients_names = [g.name for g in gradients]
                # gradients_names = [
                #     name.split('mlp/')[1].split('/Reshape')[0].split('/MatMul_1')[0]
                #     for name in gradients_names
                # ]
                # print_values = ['grad']
                # for gg, name in zip(gradients, gradients_names):
                #   print_values.append(name)
                #   print_values.append(tf.reduce_mean(tf.abs(gg)))
                # dbg = tf.Print(tf.constant(0.0), print_values, summarize=100)
                # with tf.control_dependencies([dbg]):
                #   gradients = [tf.identity(g) for g in gradients]

                # Stopping the gradient here corresponds to what was done in the
                # original L2L NIPS submission. However it looks like things like
                # BatchNorm, etc. don't support second-derivatives so we still need
                # this term.
                if not second_derivatives:
                    gradients = [tf.stop_gradient(g) for g in gradients]

            with tf.name_scope("deltas"):
                deltas, state_next = zip(
                    *[net(g, s) for g, s in zip(gradients, state)])
                deltas = [d for d in deltas]
                #   print_values = ['delta']
                #   for dd, name in zip(deltas, gradients_names):
                #     print_values.append(name)
                #     print_values.append(tf.reduce_mean(tf.abs(dd)))
                #   dbg = tf.Print(tf.constant(0.0), print_values, summarize=100)

                #   with tf.control_dependencies([dbg]):
                #     deltas = [tf.identity(d) for d in deltas]
                state_next = list(state_next)

            # compute the "learning rate" by delta/gradient
            # Flatten and concatenate all gradients / deltas of this subset.
            grad_vec = tf.concat([tf.reshape(gg, [-1]) for gg in gradients],
                                 axis=0)
            delta_vec = tf.concat([tf.reshape(dd, [-1]) for dd in deltas],
                                  axis=0)

            # dominant_grad_idx = tf.argmax(tf.abs(delta_vec))
            # NOTE(review): 4176 is a hard-coded index into the concatenated
            # vectors above; it is out of range for subsets with 4176 or fewer
            # parameters — confirm the intended problem size or restore the
            # argmax above.
            dominant_grad_idx = 4176

            # Norms are currently only referenced by the commented-out lr below.
            delta_vec_norm = tf.sqrt(tf.reduce_sum(delta_vec * delta_vec))
            grad_vec_norm = tf.sqrt(tf.reduce_sum(grad_vec * grad_vec))

            # lr = tf.div(tf.abs(delta_vec), tf.abs(grad_vec) + tf.constant(1.0e-16))
            # lr = tf.div(delta_vec_norm, grad_vec_norm + tf.constant(1.0e-16))
            lr = tf.div(
                tf.abs(delta_vec[dominant_grad_idx]),
                tf.abs(grad_vec[dominant_grad_idx]) + tf.constant(1.0e-16))

            log_lr = tf.log(lr)
            # dbg = tf.Print(tf.constant(0.0), [delta_vec_norm])
            # dbg = tf.Print(tf.constant(0.0), [dominant_grad_idx])
            # dbg = tf.Print(tf.constant(0.0), [lr])
            dbg = tf.constant(0.0)

            # `dbg` and `log_lr` are used only as control dependencies of the
            # returned deltas.
            with tf.control_dependencies([dbg, log_lr]):
                deltas = [tf.identity(d) for d in deltas]

            # NOTE(review): `update` is called from `time_step`, i.e. inside the
            # tf.while_loop body, so this summary op lives in the loop context;
            # fetching it through tf.summary.merge_all() outside the loop is
            # likely to fail — confirm how this summary is consumed.
            tf.summary.scalar("learning_rate", lr)

            # tf.summary.histogram("learning_rate", lr)
            # tf.summary.histogram("log_learning_rate", log_lr)

            return deltas, state_next

        def time_step(t, fx_array, x, state):
            """While loop body."""
            x_next = list(x)
            state_next = []

            with tf.name_scope("fx"):
                fx = _make_with_custom_variables(make_loss, x)
                fx_array = fx_array.write(t, fx)

            with tf.name_scope("dx"):
                for subset, key, s_i in zip(subsets, net_keys, state):
                    x_i = [x[j] for j in subset]
                    deltas, s_i_next = update(nets[key], fx, x_i, s_i)

                    for idx, j in enumerate(subset):
                        x_next[j] += deltas[idx]
                    state_next.append(s_i_next)

            with tf.name_scope("t_next"):
                t_next = t + 1

            return t_next, fx_array, x_next, state_next

        # Ops created in this context get the (possibly empty) state-loading
        # assignments as control dependencies.
        with tf.control_dependencies(assign_ops):
            # Define the while loop.
            # fx_array holds len_unroll in-loop losses plus the final one.
            fx_array = tf.TensorArray(tf.float32,
                                      size=len_unroll + 1,
                                      clear_after_read=False)
            _, fx_array, x_final, s_final = tf.while_loop(
                cond=lambda t, *_: t < len_unroll,
                body=time_step,
                loop_vars=(0, fx_array, x, state),
                parallel_iterations=1,
                swap_memory=True,
                name="unroll")

            with tf.name_scope("fx"):
                fx_final = _make_with_custom_variables(make_loss, x_final)
                fx_array = fx_array.write(len_unroll, fx_final)

            loss = tf.reduce_sum(fx_array.stack(), name="loss")

            # Reset the state; should be called at the beginning of an epoch.
            # Only the RNN state variables are re-initialised here; x and the
            # problem constants are left untouched.
            with tf.name_scope("reset"):
                # variables = (nest.flatten(state) + x + constants)
                variables = nest.flatten(state)
                # Empty array as part of the reset process.
                reset = [tf.variables_initializer(variables), fx_array.close()]

            # Operator to update the parameters and the RNN state after our loop, but
            # during an epoch.
            with tf.name_scope("update"):
                update = (nest.flatten(_nested_assign(x, x_final)) +
                          nest.flatten(_nested_assign(state, s_final)))

            # Log internal variables.
            for k, net in nets.items():
                print("Optimizer '{}' variables".format(k))
                print([op.name for op in snt.get_variables_in_module(net)])

        return MetaLoss(loss, update, reset, fx_final, x_final, s_final)
Beispiel #25
0
 def b(self):
   """Returns this module's variable whose raw name is "b".

   The module must own exactly two variables (checked by assertion), and
   exactly one of them must map to "b" under `self._raw_name`.
   """
   module_vars = snt.get_variables_in_module(self)
   num_vars = len(module_vars)
   assert num_vars == 2, "Found not 2 but %d" % num_vars
   candidates = [var for var in module_vars
                 if self._raw_name(var.name) == "b"]
   assert len(candidates) == 1
   return candidates[0]
Beispiel #26
0
 def b(self):
     """Returns the single module variable whose raw name is "b"."""
     matches = [
         var for var in snt.get_variables_in_module(self)
         if self._raw_name(var.name) == "b"
     ]
     assert len(matches) == 1
     return matches[0]
Beispiel #27
0
def get_variables_in_modules(module_list):
  """Returns the variables of every module in `module_list` as one list.

  Variables are grouped by module, in the order the modules are given.
  """
  return [variable
          for module in module_list
          for variable in snt.get_variables_in_module(module)]
Beispiel #28
0
def get_variables_in_modules(module_list):
  """Concatenates `snt.get_variables_in_module` over `module_list`, in order."""
  collected = []
  for module in module_list:
    collected += snt.get_variables_in_module(module)
  return collected
Beispiel #29
0
 def get_trainable_vars(self):
     """Returns every variable created inside this Sonnet module."""
     module_variables = snt.get_variables_in_module(self)
     return module_variables
Beispiel #30
0
def main(unused_argv):
    """Trains the EEG classifier, optionally behind a pretrained adaptor.

    Reads `eeg-train.tfr` / `eeg-test.tfr` from `FLAGS.data_dir`, builds the
    training and cross-validation graphs, and runs `FLAGS.num_iterations`
    training steps. Validation summaries and training summaries are written
    periodically, and the final model is saved to
    `<FLAGS.output_dir>/model.ckpt`.

    When `FLAGS.adp` is set, an `Adaptor` + `ReducedClassifier` pair is built
    and the adaptor's variables are first restored from the latest checkpoint
    in `FLAGS.checkpoint_dir`. If no checkpoint is found, the function prints
    a message and returns without training.

    Args:
      unused_argv: command-line arguments left over after flag parsing.

    Raises:
      ValueError: if the `data_dir`, `output_dir`, `checkpoint_dir` or `gpus`
        flag is empty (or `data_dir` does not exist).
    """
    summ = Summaries()

    if FLAGS.data_dir == '' or not os.path.exists(FLAGS.data_dir):
        raise ValueError('invalid data directory {}'.format(FLAGS.data_dir))

    train_data_path = os.path.join(FLAGS.data_dir, 'eeg-train.tfr')
    xval_data_path = os.path.join(FLAGS.data_dir, 'eeg-test.tfr')

    if FLAGS.output_dir == '':
        raise ValueError('invalid output directory {}'.format(
            FLAGS.output_dir))
    elif not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    if FLAGS.checkpoint_dir == '':
        # Report the flag that is actually invalid.
        raise ValueError('invalid checkpoint directory {}'.format(
            FLAGS.checkpoint_dir))

    event_log_dir = os.path.join(FLAGS.output_dir, '')

    checkpoint_path = os.path.join(FLAGS.output_dir, 'model.ckpt')

    print('Constructing models.')

    if FLAGS.adp:
        adaptor = Adaptor()
        classifier = ReducedClassifier()

        train_loss, train_op, train_summ_op = \
            train(train_data_path, adaptor, classifier, summ)
        xval_op, xval_summ_op = xval(xval_data_path, adaptor, classifier, summ)
    else:
        classifier = Classifier(FLAGS.num_points, FLAGS.sampling_rate)

        train_loss, train_op, train_summ_op = \
            train(train_data_path, None, classifier, summ)
        xval_op, xval_summ_op = xval(xval_data_path, None, classifier, summ)

    print('Constructing saver.')

    if FLAGS.adp:
        adaptor_variables = snt.get_variables_in_module(adaptor)
        # saver_adaptor restores only the adaptor; saver checkpoints both.
        saver_adaptor = tf.train.Saver(adaptor_variables)
        saver = tf.train.Saver(
            adaptor_variables + snt.get_variables_in_module(classifier))
    else:
        saver = tf.train.Saver(snt.get_variables_in_module(classifier))

    # Start running operations on the Graph. allow_soft_placement must be set to
    # True to as some of the ops do not have GPU implementations.
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)

    # Raise (rather than assert) so the check survives `python -O` and matches
    # the other flag validation above.
    if FLAGS.gpus == '':
        raise ValueError('invalid GPU specification')
    config.gpu_options.visible_device_list = FLAGS.gpus

    # Build an initialization operation to run below.
    init = [
        tf.global_variables_initializer(),
        tf.local_variables_initializer()
    ]

    # Validation / summaries fire on iterations 1, 1+k, 1+2k, ... Comparing
    # against `1 % k` (instead of 1) keeps an interval of 1 working: `itr % 1`
    # is always 0, so `== 1` would never fire.
    validation_phase = 1 % FLAGS.validation_interval
    summary_phase = 1 % FLAGS.summary_interval

    with tf.Session(config=config) as sess:
        sess.run(init)

        if FLAGS.adp:
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                return
            # Restores the adaptor weights from the latest checkpoint.
            saver_adaptor.restore(sess, ckpt.model_checkpoint_path)

        writer = tf.summary.FileWriter(event_log_dir, graph=sess.graph)
        try:
            # Run training.
            for itr in range(FLAGS.num_iterations):
                cost, _, train_summ_str = sess.run(
                    [train_loss, train_op, train_summ_op])
                # Print info: iteration #, cost.
                print(str(itr) + ' ' + str(cost))

                if itr % FLAGS.validation_interval == validation_phase:
                    # Run through validation set.
                    sess.run(xval_op)
                    val_summ_str = sess.run(xval_summ_op)
                    writer.add_summary(val_summ_str, itr)
                    reset_metrics(sess)

                if itr % FLAGS.summary_interval == summary_phase:
                    writer.add_summary(train_summ_str, itr)
        finally:
            # Flush pending events and release the event file.
            writer.close()

        tf.logging.info('Saving model.')
        saver.save(sess, checkpoint_path)
        tf.logging.info('Training complete')