Exemple #1
0
class Er(ContinualModel):
    """Continual learner that rehearses buffered samples next to each batch.

    Every ``observe`` call trains on the incoming batch joined with a
    minibatch read back from ``self.buffer`` and afterwards stores the
    incoming samples in that buffer.
    """
    NAME = 'er'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        """Forward all arguments to the base class and create the buffer.

        The buffer is built from ``args.buffer_size`` and ``self.device``.
        """
        super().__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def observe(self, inputs, labels, not_aug_inputs):
        """Run one optimisation step and return the scalar loss.

        :param inputs: batch fed to the network
        :param labels: targets for ``inputs``
        :param not_aug_inputs: samples handed to the buffer
            (presumably the non-augmented version of ``inputs``)
        :return: ``loss.item()`` of the step
        """
        # Remember how many samples came from the stream: after the
        # concatenation below only this prefix belongs to the new batch.
        n_stream = inputs.shape[0]

        self.opt.zero_grad()
        if not self.buffer.is_empty():
            replay_x, replay_y = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            inputs = torch.cat((inputs, replay_x))
            labels = torch.cat((labels, replay_y))

        loss = self.loss(self.net(inputs), labels)
        loss.backward()
        self.opt.step()

        # Store only the stream samples, never the replayed ones.
        stream_labels = labels[:n_stream]
        self.buffer.add_data(examples=not_aug_inputs, labels=stream_labels)

        return loss.item()
Exemple #2
0
class Der(ContinualModel):
    """Continual learner that replays stored network outputs (logits).

    The buffer keeps examples together with the logits the network
    produced for them; replayed examples are pulled towards those stored
    logits with an MSE term weighted by ``args.alpha``.
    """
    NAME = 'der'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        """Forward all arguments to the base class and create the buffer."""
        super().__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def observe(self, inputs, labels, not_aug_inputs):
        """Run one optimisation step and return the scalar loss.

        :param inputs: batch fed to the network
        :param labels: targets for ``inputs``
        :param not_aug_inputs: samples handed to the buffer
        :return: ``loss.item()`` (task loss plus the replay term, if any)
        """
        self.opt.zero_grad()

        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)

        if not self.buffer.is_empty():
            replay_x, replay_logits = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            # Distance between the current and the stored responses.
            logit_drift = F.mse_loss(self.net(replay_x), replay_logits)
            loss += self.args.alpha * logit_drift

        loss.backward()
        self.opt.step()

        # ``outputs`` was computed before the optimiser step; ``.data``
        # stores it without its autograd history.
        self.buffer.add_data(examples=not_aug_inputs, logits=outputs.data)

        return loss.item()
Exemple #3
0
class Mer(ContinualModel):
    """Continual learner combining buffer replay with Reptile-style
    interpolation of the parameters (within and across batches)."""
    NAME = 'mer'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        """Forward all arguments to the base class and create the buffer."""
        super().__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def draw_batches(self, inp, lab):
        """Build ``args.batch_num`` training batches around one sample.

        ``inp`` gets a leading dimension via ``unsqueeze(0)``, so it is
        treated as a single example. When the buffer holds data, each
        batch is a buffer minibatch followed by that example; otherwise
        the batch contains the example alone.

        :return: list of ``(inputs, labels)`` tuples
        """
        batches = []
        for _ in range(self.args.batch_num):
            if self.buffer.is_empty():
                pair = (inp.unsqueeze(0),
                        torch.tensor([lab]).unsqueeze(0).to(self.device))
            else:
                replay_x, replay_y = self.buffer.get_data(
                    self.args.minibatch_size, transform=self.transform)
                current_label = torch.tensor([lab]).to(self.device)
                pair = (torch.cat((replay_x, inp.unsqueeze(0))),
                        torch.cat((replay_y, current_label)))
            batches.append(pair)
        return batches

    def observe(self, inputs, labels, not_aug_inputs):
        """Train on the drawn batches and apply both meta-updates.

        :return: ``loss.item()`` of the last batch processed
        """
        batches = self.draw_batches(inputs, labels)
        # Parameters before any batch has been processed.
        theta_A0 = self.net.get_params().data.clone()

        for batch_inputs, batch_labels in batches:
            # Parameters before this particular batch.
            theta_Wi0 = self.net.get_params().data.clone()

            # within-batch step
            self.opt.zero_grad()
            outputs = self.net(batch_inputs)
            loss = self.loss(outputs, batch_labels.squeeze(-1))
            loss.backward()
            self.opt.step()

            # within batch reptile meta-update: keep only a fraction
            # ``beta`` of the displacement produced by the step above
            step_delta = self.net.get_params() - theta_Wi0
            self.net.set_params(theta_Wi0 + self.args.beta * step_delta)

        self.buffer.add_data(examples=not_aug_inputs.unsqueeze(0),
                             labels=labels)

        # across batch reptile meta-update: keep a fraction ``gamma`` of
        # the total displacement accumulated over all batches
        total_delta = self.net.get_params() - theta_A0
        self.net.set_params(theta_A0 + self.args.gamma * total_delta)

        return loss.item()
Exemple #4
0
class AGem(ContinualModel):
    """Continual learner that corrects the current gradient whenever it
    conflicts (negative dot product) with a gradient computed on buffer
    samples. The buffer is filled at the end of each task."""
    NAME = 'agem'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        """Create the buffer and the two flat gradient work tensors."""
        super().__init__(backbone, loss, args, transform)

        self.buffer = Buffer(self.args.buffer_size, self.device)
        # Element count of every parameter, in ``self.parameters()`` order.
        self.grad_dims = [param.data.numel() for param in self.parameters()]
        # Flat storage for the current-batch gradient and the buffer
        # gradient (allocated without initialisation).
        self.grad_xy = torch.Tensor(np.sum(self.grad_dims)).to(self.device)
        self.grad_er = torch.Tensor(np.sum(self.grad_dims)).to(self.device)
        # The transform is applied to buffer samples only if ``args.iba``
        # is truthy.
        self.transform = transform if self.args.iba else None

    def end_task(self, dataset):
        """Add the first batch of a non-augmented loader to the buffer.

        The loader is requested with ``buffer_size // N_TASKS`` as its
        second argument.
        """
        quota = self.args.buffer_size // dataset.N_TASKS
        loader = dataset.not_aug_dataloader(self.args, quota)
        first_x, first_y = next(iter(loader))[:2]
        self.buffer.add_data(examples=first_x.to(self.device),
                             labels=first_y.to(self.device))

    def observe(self, inputs, labels, not_aug_inputs):
        """Run one optimisation step and return the current-batch loss."""
        self.zero_grad()
        p = self.net.forward(inputs)
        loss = self.loss(p, labels)
        loss.backward()

        if not self.buffer.is_empty():
            # Save the gradient of the current batch before it is replaced.
            store_grad(self.parameters, self.grad_xy, self.grad_dims)

            replay_x, replay_y = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            self.net.zero_grad()
            replay_out = self.net.forward(replay_x)
            penalty = self.loss(replay_out, replay_y)
            penalty.backward()
            store_grad(self.parameters, self.grad_er, self.grad_dims)

            # A negative dot product means the two gradients conflict;
            # only then is the projected gradient used.
            conflict = torch.dot(self.grad_xy, self.grad_er).item() < 0
            if conflict:
                final_grad = project(gxy=self.grad_xy, ger=self.grad_er)
            else:
                final_grad = self.grad_xy
            overwrite_grad(self.parameters, final_grad, self.grad_dims)

        self.opt.step()

        return loss.item()
Exemple #5
0
class AGemr(ContinualModel):
    """Variant of the gradient-projection learner whose buffer is fed on
    every ``observe`` call instead of at task boundaries."""
    NAME = 'agem_r'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        """Create the buffer and the two flat gradient work tensors."""
        super().__init__(backbone, loss, args, transform)

        self.buffer = Buffer(self.args.buffer_size, self.device)
        # Element count of every parameter, in ``self.parameters()`` order.
        self.grad_dims = [param.data.numel() for param in self.parameters()]
        flat_size = np.sum(self.grad_dims)
        # Flat storage for the current-batch and the buffer gradient
        # (allocated without initialisation).
        self.grad_xy = torch.Tensor(flat_size).to(self.device)
        self.grad_er = torch.Tensor(flat_size).to(self.device)
        self.current_task = 0

    def observe(self, inputs, labels, not_aug_inputs):
        """Run one optimisation step, feed the buffer, return the loss."""
        self.zero_grad()
        p = self.net.forward(inputs)
        loss = self.loss(p, labels)
        loss.backward()

        if not self.buffer.is_empty():
            # Save the gradient of the current batch before it is replaced.
            store_grad(self.parameters, self.grad_xy, self.grad_dims)

            # Buffer samples are requested without a transform here.
            replay_x, replay_y = self.buffer.get_data(self.args.minibatch_size)
            self.net.zero_grad()
            replay_out = self.net.forward(replay_x)
            penalty = self.loss(replay_out, replay_y)
            penalty.backward()
            store_grad(self.parameters, self.grad_er, self.grad_dims)

            # Project only when the two gradients point in conflicting
            # directions (negative dot product).
            if torch.dot(self.grad_xy, self.grad_er).item() < 0:
                chosen_grad = project(gxy=self.grad_xy, ger=self.grad_er)
            else:
                chosen_grad = self.grad_xy
            overwrite_grad(self.parameters, chosen_grad, self.grad_dims)

        self.opt.step()

        self.buffer.add_data(examples=not_aug_inputs, labels=labels)

        return loss.item()
Exemple #6
0
class Gem(ContinualModel):
    """Continual learner that keeps one reference gradient per finished
    task and projects the current gradient when it conflicts with any of
    them."""
    NAME = 'gem'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        """Create the buffer and the flat gradient storage."""
        super().__init__(backbone, loss, args, transform)
        self.current_task = 0
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.transform = transform

        # Allocate temporary synaptic memory: element count of every
        # parameter, in ``self.parameters()`` order.
        self.grad_dims = [pp.data.numel() for pp in self.parameters()]

        # One flat gradient per finished task (appended in ``end_task``)
        # and one for the batch currently being observed.
        self.grads_cs = []
        self.grads_da = torch.zeros(np.sum(self.grad_dims)).to(self.device)
        # The transform is applied to buffer samples only if ``args.iba``
        # is truthy.
        self.transform = transform if self.args.iba else None

    def end_task(self, dataset):
        """Register the finished task and store samples of it.

        Appends a zeroed gradient slot for the task and adds the first
        batch of a non-augmented loader to the buffer, tagged with the
        index of the task that just ended.
        """
        self.current_task += 1
        flat_size = np.sum(self.grad_dims)
        self.grads_cs.append(torch.zeros(flat_size).to(self.device))

        # add data to the buffer
        samples_per_task = self.args.buffer_size // dataset.N_TASKS

        loader = dataset.not_aug_dataloader(self.args, samples_per_task)
        cur_x, cur_y = next(iter(loader))[:2]
        finished_task = self.current_task - 1
        task_ids = torch.ones(samples_per_task, dtype=torch.long).to(
            self.device) * finished_task
        self.buffer.add_data(
            examples=cur_x.to(self.device),
            labels=cur_y.to(self.device),
            task_labels=task_ids)

    def observe(self, inputs, labels, not_aug_inputs):
        """Run one optimisation step and return the current-batch loss."""
        if not self.buffer.is_empty():
            buf_inputs, buf_labels, buf_task_labels = self.buffer.get_data(
                self.args.buffer_size, transform=self.transform)
            chunk = self.args.batch_size

            for tt in buf_task_labels.unique():
                # compute gradient on the memory buffer
                self.opt.zero_grad()
                task_mask = buf_task_labels == tt
                cur_task_inputs = buf_inputs[task_mask]
                cur_task_labels = buf_labels[task_mask]

                # Walk over the task samples in chunks of ``batch_size``;
                # summing each chunk and dividing by the task sample count
                # accumulates the mean-loss gradient over the whole task.
                for start in range(0, len(cur_task_inputs), chunk):
                    stop = start + chunk
                    cur_task_outputs = self.forward(
                        cur_task_inputs[start:stop])
                    penalty = self.loss(
                        cur_task_outputs,
                        cur_task_labels[start:stop],
                        reduction='sum') / cur_task_inputs.shape[0]
                    penalty.backward()
                store_grad(self.parameters, self.grads_cs[tt], self.grad_dims)

        # now compute the grad on the current data
        self.opt.zero_grad()
        outputs = self.forward(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()

        # check if gradient violates buffer constraints
        if not self.buffer.is_empty():
            # copy gradient
            store_grad(self.parameters, self.grads_da, self.grad_dims)

            # One column per stored task gradient.
            task_grads = torch.stack(self.grads_cs).T
            dot_prod = torch.mm(self.grads_da.unsqueeze(0), task_grads)
            if (dot_prod < 0).sum() != 0:
                # The projection result is read back from ``grads_da``
                # below, so the call is expected to update it through
                # the unsqueezed view.
                project2cone2(self.grads_da.unsqueeze(1),
                              task_grads,
                              margin=self.args.gamma)
                # copy gradients back
                overwrite_grad(self.parameters, self.grads_da, self.grad_dims)

        self.opt.step()

        return loss.item()
Exemple #7
0
class HAL(ContinualModel):
    """Continual learner combining a ring buffer with learned "anchor"
    inputs (one per class) whose predictions are kept stable across
    updates."""
    NAME = 'hal'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        """Create the ring buffer, copy the hyper-parameters from ``args``
        and build a spare copy of the backbone with its own SGD optimiser.

        ``self.anchors``, ``self.phi`` and ``self.input_shape`` are not
        created here; ``observe`` creates them lazily on first use.
        """
        super(HAL, self).__init__(backbone, loss, args, transform)
        # Number of tasks finished so far (incremented in ``end_task``).
        self.task_number = 0
        self.buffer = Buffer(self.args.buffer_size,
                             self.device,
                             get_dataset(args).N_TASKS,
                             mode='ring')
        self.hal_lambda = args.hal_lambda
        self.beta = args.beta
        self.gamma = args.gamma
        self.anchor_optimization_steps = 100
        self.finetuning_epochs = 1
        self.dataset = get_dataset(args)
        # Scratch network used by ``get_anchors`` so that the anchor search
        # never modifies ``self.net``.
        self.spare_model = self.dataset.get_backbone()
        self.spare_model.to(self.device)
        self.spare_opt = SGD(self.spare_model.parameters(), lr=self.args.lr)

    def end_task(self, dataset):
        """Advance the task counter, reset the buffer's seen-example count
        for the next task and learn the anchors of the finished task.

        Reads ``self.anchors`` and deletes ``self.phi``; both are created
        by ``observe``.
        """
        self.task_number += 1
        # ring buffer mgmt (only if we are not loading a model)
        if self.task_number > self.buffer.task_number:
            self.buffer.num_seen_examples = 0
            self.buffer.task_number = self.task_number
        # get anchors (provided that we are not loading the model)
        if len(self.anchors) < self.task_number * dataset.N_CLASSES_PER_TASK:
            self.get_anchors(dataset)
            # ``observe`` rebuilds phi (as zeros) on its next call.
            del self.phi

    def get_anchors(self, dataset):
        """Optimise one anchor input per class of the current task and
        append it to ``self.anchors``.

        All optimisation happens on ``self.spare_model``; ``self.net`` is
        only read to obtain its current parameters.
        """
        # theta_t: current parameters of the main network.
        theta_t = self.net.get_params().detach().clone()
        self.spare_model.set_params(theta_t)

        # fine tune on memory buffer
        for _ in range(self.finetuning_epochs):
            inputs, labels = self.buffer.get_data(self.args.batch_size,
                                                  transform=self.transform)
            self.spare_opt.zero_grad()
            out = self.spare_model(inputs)
            loss = self.loss(out, labels)
            loss.backward()
            self.spare_opt.step()

        # theta_m: parameters after the fine-tuning step(s) on the buffer.
        theta_m = self.spare_model.get_params().detach().clone()

        classes_for_this_task = np.unique(dataset.train_loader.dataset.targets)

        for a_class in classes_for_this_task:
            # The anchor starts as uniform noise and is the only tensor
            # optimised by ``e_t_opt``.
            e_t = torch.rand(self.input_shape,
                             requires_grad=True,
                             device=self.device)
            e_t_opt = SGD([e_t], lr=self.args.lr)
            print(file=sys.stderr)
            for i in range(self.anchor_optimization_steps):
                e_t_opt.zero_grad()
                # NOTE(review): cum_loss is accumulated but never read.
                cum_loss = 0

                # Term 1: negated loss under the fine-tuned parameters.
                # ``spare_opt.zero_grad()`` clears only the model gradients,
                # so the gradients on ``e_t`` from the three terms add up.
                self.spare_opt.zero_grad()
                self.spare_model.set_params(theta_m.detach().clone())
                loss = -torch.sum(
                    self.loss(self.spare_model(e_t.unsqueeze(0)),
                              torch.tensor([a_class]).to(self.device)))
                loss.backward()
                cum_loss += loss.item()

                # Term 2: loss under the current parameters.
                self.spare_opt.zero_grad()
                self.spare_model.set_params(theta_t.detach().clone())
                loss = torch.sum(
                    self.loss(self.spare_model(e_t.unsqueeze(0)),
                              torch.tensor([a_class]).to(self.device)))
                loss.backward()
                cum_loss += loss.item()

                # Term 3: squared distance between the anchor's features
                # and the running feature average ``phi``, scaled by gamma.
                self.spare_opt.zero_grad()
                loss = torch.sum(self.gamma *
                                 (self.spare_model.features(e_t.unsqueeze(0)) -
                                  self.phi)**2)
                assert not self.phi.requires_grad
                loss.backward()
                cum_loss += loss.item()

                e_t_opt.step()

            # Freeze the optimised anchor before storing it.
            e_t = e_t.detach()
            e_t.requires_grad = False
            self.anchors = torch.cat((self.anchors, e_t.unsqueeze(0)))
            del e_t
            print('Total anchors:', len(self.anchors), file=sys.stderr)

        self.spare_model.zero_grad()

    def observe(self, inputs, labels, not_aug_inputs):
        """Run one training step (plus an anchor-stability step once
        anchors exist) and return the summed scalar loss.

        :param inputs: batch fed to the network
        :param labels: targets for ``inputs``
        :param not_aug_inputs: samples handed to the buffer
        :return: loss of the first step plus, when anchors exist, the
            anchor loss
        """
        real_batch_size = inputs.shape[0]
        # Lazily create the state that depends on the input shape.
        if not hasattr(self, 'input_shape'):
            self.input_shape = inputs.shape[1:]
        if not hasattr(self, 'anchors'):
            # Empty tensor with zero anchors of shape ``input_shape``.
            self.anchors = torch.zeros(tuple([0] + list(self.input_shape))).to(
                self.device)
        if not hasattr(self, 'phi'):
            print('Building phi', file=sys.stderr)
            with torch.no_grad():
                # Zeros shaped like the features of a single input.
                self.phi = torch.zeros_like(self.net.features(
                    inputs[0].unsqueeze(0)),
                                            requires_grad=False)
            assert not self.phi.requires_grad

        if not self.buffer.is_empty():
            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            inputs = torch.cat((inputs, buf_inputs))
            labels = torch.cat((labels, buf_labels))

        # Snapshot taken so the first update can be undone further down.
        old_weights = self.net.get_params().detach().clone()

        self.opt.zero_grad()
        outputs = self.net(inputs)

        k = self.task_number

        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()

        first_loss = 0

        # One anchor per class of every finished task is expected.
        assert len(self.anchors) == self.dataset.N_CLASSES_PER_TASK * k

        if len(self.anchors) > 0:
            first_loss = loss.item()
            # Anchor predictions after the update above (no graph kept).
            with torch.no_grad():
                pred_anchors = self.net(self.anchors)

            # Undo the update, then measure how far it moved the anchor
            # predictions.
            self.net.set_params(old_weights)
            pred_anchors -= self.net(self.anchors)
            loss = self.hal_lambda * (pred_anchors**2).mean()
            # NOTE(review): there is no zero_grad before this backward, so
            # the anchor gradient is added to the gradient still stored
            # from the first backward; the step below therefore applies
            # both from the restored weights. Presumably intentional.
            loss.backward()
            self.opt.step()

        # Exponential moving average of the mean features of the
        # current-batch samples (buffer samples excluded).
        with torch.no_grad():
            self.phi = self.beta * self.phi + (
                1 - self.beta) * self.net.features(
                    inputs[:real_batch_size]).mean(0)

        self.buffer.add_data(examples=not_aug_inputs,
                             labels=labels[:real_batch_size])

        return first_loss + loss.item()
class BasicManager(object):
	
	def __init__(self, session, device, id, action_shape, state_shape, concat_size=0, global_network=None, training=True):
		"""Set up the manager, its agents and (when training) its buffers.

		In training mode a manager created with ``global_network=None`` is
		itself the global one and creates the gradient optimizers; any other
		training manager reuses the global clip values and is bound to the
		global network. When not training, no agents are built: the model
		list of ``global_network`` is reused as is.
		"""
		self.training = training
		self.session = session
		self.id = id
		self.device = device
		self.state_shape = state_shape
		# Must run before anything that reads self.model_size.
		self.set_model_size()
		if self.training:
			self.global_network = global_network
			# Gradient optimizer and clip range
			if not self.is_global_network():
				self.clip = self.global_network.clip
			else:
				self.initialize_gradient_optimizer()
			# Build agents (build_agents reads self.clip, set just above)
			self.model_list = []
			self.build_agents(state_shape=state_shape, action_shape=action_shape, concat_size=concat_size)
			# Build experience buffer
			if flags.replay_ratio > 0:
				if flags.prioritized_replay:
					self.experience_buffer = PrioritizedBuffer(size=flags.replay_buffer_size)
					# self.beta_schedule = LinearSchedule(flags.max_time_step, initial_p=0.4, final_p=1.0)
				else:
					self.experience_buffer = Buffer(size=flags.replay_buffer_size)
			if flags.predict_reward:
				self.reward_prediction_buffer = Buffer(size=flags.reward_prediction_buffer_size)
			# Bind optimizer to global
			if not self.is_global_network():
				self.bind_to_global(self.global_network)
			# Count based exploration
			if flags.use_count_based_exploration_reward:
				self.projection = None
				self.projection_dataset = []
			# One dict of loss histories per model, filled by train()
			if flags.print_loss:
				self._loss_list = [{} for _ in range(self.model_size)]
		else:
			# Not training: global_network must be provided, its models are shared.
			self.global_network = None
			self.model_list = global_network.model_list
		# Statistics
		self._model_usage_list = deque()
			
	def is_global_network(self):
		return self.global_network is None
			
	def set_model_size(self):
		self.model_size = 1
		self.agents_set = set([0])
			
	def build_agents(self, state_shape, action_shape, concat_size):
		"""Create the single agent of this manager and append it to self.model_list.

		The agent class is looked up by name as ``<flags.network>_Network``.
		"""
		# NOTE(review): eval() executes whatever string flags.network holds;
		# make sure that flag can never carry untrusted input.
		agent=eval('{}_Network'.format(flags.network))(
			session=self.session, 
			id='{0}_{1}'.format(self.id, 0), 
			device=self.device, 
			state_shape=state_shape, 
			action_shape=action_shape, 
			concat_size=concat_size, 
			clip=self.clip[0], 
			predict_reward=flags.predict_reward, 
			training = self.training
		)
		self.model_list.append(agent)
			
	def sync(self):
		# assert not self.is_global_network(), 'you are trying to sync the global network with itself'
		for i in range(self.model_size):
			agent = self.model_list[i]
			sync = self.sync_list[i]
			agent.sync(sync)
			
	def initialize_gradient_optimizer(self):
		"""Create, for every model, a global step, a learning rate, a clip
		value and a gradient optimizer, stored in four parallel lists.

		Learning rate and clip are either the plain flag value or, when the
		matching ``*_decay`` flag is set, the result of the ``tf.train``
		function named by the matching ``*_annealing_function`` flag.
		"""
		# NOTE(review): the eval() calls below execute strings built from flags;
		# make sure those flags can never carry untrusted input.
		self.global_step = []
		self.learning_rate = []
		self.clip = []
		self.gradient_optimizer = []
		for i in range(self.model_size):
			# global step
			self.global_step.append( tf.Variable(0, trainable=False) )
			# learning rate
			self.learning_rate.append( eval('tf.train.'+flags.alpha_annealing_function)(learning_rate=flags.alpha, global_step=self.global_step[i], decay_steps=flags.alpha_decay_steps, decay_rate=flags.alpha_decay_rate) if flags.alpha_decay else flags.alpha )
			# clip (the annealing function receives the clip value as its "learning_rate")
			self.clip.append( eval('tf.train.'+flags.clip_annealing_function)(learning_rate=flags.clip, global_step=self.global_step[i], decay_steps=flags.clip_decay_steps, decay_rate=flags.clip_decay_rate) if flags.clip_decay else flags.clip )
			# gradient optimizer
			self.gradient_optimizer.append( eval('tf.train.'+flags.optimizer+'Optimizer')(learning_rate=self.learning_rate[i], use_locking=True) )
			
	def bind_to_global(self, global_network):
		self.sync_list = []
		for i in range(self.model_size):
			local_agent = self.get_model(i)
			global_agent = global_network.get_model(i)
			local_agent.minimize_local_loss(optimizer=global_network.gradient_optimizer[i], global_step=global_network.global_step[i], global_var_list=global_agent.get_shared_keys())
			self.sync_list.append(local_agent.bind_sync(global_agent)) # for syncing local network with global one

	def get_model(self, id):
		return self.model_list[id]
		
	def get_statistics(self):
		stats = {}
		if self.training:
			# build loss statistics
			if flags.print_loss:
				for i in range(self.model_size):
					for key, value in self._loss_list[i].items():
						stats['loss_{}{}_avg'.format(key,i)] = np.average(value)
		# build models usage statistics
		if self.model_size > 1:
			total_usage = 0
			usage_matrix = {}
			for u in self._model_usage_list:
				if not (u in usage_matrix):
					usage_matrix[u] = 0
				usage_matrix[u] += 1
				total_usage += 1
			for i in range(self.model_size):
				stats['model_{}'.format(i)] = 0
			for key, value in usage_matrix.items():
				stats['model_{}'.format(key)] = value/total_usage if total_usage != 0 else 0
		return stats
		
	def add_to_statistics(self, id):
		"""Record that model ``id`` was used, keeping at most
		``flags.match_count_for_evaluation`` entries."""
		history = self._model_usage_list
		history.append(id)
		# remove old statistics
		if len(history) > flags.match_count_for_evaluation:
			history.popleft()
		
	def get_shared_keys(self):
		vars = []
		for agent in self.model_list:
			vars += agent.get_shared_keys()
		return vars
		
	def reset(self):
		"""Reset the step counter, the active agent and the internal states."""
		self.step = 0
		self.agent_id = 0
		# Internal states: a single shared one, or one slot per model
		if flags.share_internal_state:
			self.internal_states = None
		else:
			self.internal_states = [None]*self.model_size
		# Count based exploration: start with an empty visit table
		if self.training and flags.use_count_based_exploration_reward:
			self.hash_state_table = {}
			
	def initialize_new_batch(self):
		"""Replace ``self.batch`` with a fresh ``ExperienceBatch``."""
		model_count = self.model_size
		self.batch = ExperienceBatch(model_count)
		
	def estimate_value(self, agent_id, states, concats=None, internal_state=None):
		return self.get_model(agent_id).predict_value(states=states, concats=concats, internal_state=internal_state)
		
	def act(self, act_function, state, concat=None):
		"""Choose an action for ``state``, execute it and record the step.

		``act_function`` is called with the chosen action and must return
		``(new_state, extrinsic_reward, terminal)``.

		Returns ``(new_state, value, action, total_reward, terminal, policy)``
		where ``total_reward`` is a float32 array
		``[extrinsic_reward, intrinsic_reward]``.
		"""
		agent_id = self.agent_id
		agent = self.get_model(agent_id)

		# The state passed to the model is the one from BEFORE the prediction;
		# it is also what gets stored in the batch below.
		if flags.share_internal_state:
			internal_state = self.internal_states
		else:
			internal_state = self.internal_states[agent_id]
		action_batch, value_batch, policy_batch, new_internal_state = agent.predict_action(states=[state], concats=[concat], internal_state=internal_state)
		if flags.share_internal_state:
			self.internal_states = new_internal_state
		else:
			self.internal_states[agent_id] = new_internal_state

		# single-element batches: take the only entry of each
		action = action_batch[0]
		value = value_batch[0]
		policy = policy_batch[0]
		new_state, extrinsic_reward, terminal = act_function(action)

		intrinsic_reward = 0
		if self.training:
			if flags.clip_reward:
				extrinsic_reward = np.clip(extrinsic_reward, flags.min_reward, flags.max_reward)
			if flags.use_count_based_exploration_reward: # intrinsic reward
				intrinsic_reward += self.get_count_based_exploration_reward(new_state)

		total_reward = np.array([extrinsic_reward, intrinsic_reward], dtype=np.float32)
		if self.training:
			self.batch.add_action(agent_id=agent_id, state=state, concat=concat, action=action, policy=policy, reward=total_reward, value=value, internal_state=internal_state)
		# update step at the end of the action
		self.step += 1
		# return result
		return new_state, value, action, total_reward, terminal, policy

	def get_count_based_exploration_reward(self, new_state):
		"""Return a visit-count based bonus for ``new_state``.

		States are collected until ``flags.projection_dataset_size`` samples
		are available; the random projection is then (re)fitted on them and
		the collection restarts. Returns 0 until a projection exists.
		"""
		if len(self.projection_dataset) < flags.projection_dataset_size:
			self.projection_dataset.append(new_state.flatten())
		if len(self.projection_dataset) == flags.projection_dataset_size:
			if self.projection is None:
				# http://scikit-learn.org/stable/modules/random_projection.html
				n_components = flags.exploration_hash_size if flags.exploration_hash_size > 0 else 'auto'
				self.projection = SparseRandomProjection(n_components=n_components)
			self.projection.fit(self.projection_dataset)
			self.projection_dataset = [] # reset
		if self.projection is None:
			return 0
		# project to smaller dimension
		state_projection = self.projection.transform([new_state.flatten()])[0]
		# build binary locality-sensitive hash
		state_hash = ''.join('1' if x > 0 else '0' for x in state_projection)
		visits = self.hash_state_table.get(state_hash, 0) + 1
		self.hash_state_table[state_hash] = visits
		exploration_bonus = 2/np.sqrt(visits) - 1 # in [-1,1]
		if exploration_bonus > 0:
			return flags.positive_exploration_coefficient*exploration_bonus
		return flags.negative_exploration_coefficient*exploration_bonus
			
	def compute_discounted_cumulative_reward(self, batch):
		"""Make ``batch`` compute its discounted rewards and return it.

		The bootstrap value stored in the batch is used as last value,
		or 0 when the batch has none.
		"""
		if 'value' in batch.bootstrap:
			last_value = batch.bootstrap['value']
		else:
			last_value = 0.
		batch.compute_discounted_cumulative_reward(
			agents=self.agents_set,
			last_value=last_value,
			gamma=flags.gamma,
			lambd=flags.lambd
		)
		return batch
		
	def train(self, batch):
		"""Train every model that has data in ``batch``.

		Returns the list of errors reported by the trained models, in model
		order; models without data are skipped.
		"""
		# Deliberately not asserted: the global network may train as well.
		states = batch.states
		internal_states = batch.internal_states
		concats = batch.concats
		actions = batch.actions
		policies = batch.policies
		values = batch.values
		rewards = batch.rewards
		dcr = batch.discounted_cumulative_rewards
		gae = batch.generalized_advantage_estimators
		batch_error = []
		for i in range(self.model_size):
			if len(states[i]) <= 0:
				continue # nothing recorded for this model
			model = self.get_model(i)
			# reward prediction
			reward_prediction_states = None
			reward_prediction_target = None
			if model.predict_reward:
				sampled_batch = self.reward_prediction_buffer.sample()
				reward_prediction_states, reward_prediction_target = self.get_reward_prediction_tuple(sampled_batch)
			# train
			error, train_info = model.train(
				states=states[i], concats=concats[i],
				actions=actions[i], values=values[i],
				policies=policies[i],
				rewards=rewards[i],
				discounted_cumulative_rewards=dcr[i],
				generalized_advantage_estimators=gae[i],
				reward_prediction_states=reward_prediction_states,
				reward_prediction_target=reward_prediction_target,
				internal_state=internal_states[i][0]
			)
			batch_error.append(error)
			# loss statistics
			if not flags.print_loss:
				continue
			for key, value in train_info.items():
				if key not in self._loss_list[i]:
					self._loss_list[i][key] = deque()
				history = self._loss_list[i][key]
				history.append(value)
				# remove old statistics
				if len(history) > flags.match_count_for_evaluation:
					history.popleft()
		return batch_error
		
	def bootstrap(self, state, concat=None):
		"""Store in the current batch everything needed to bootstrap from
		``state``: the estimated value plus the inputs it was computed from."""
		agent_id = self.agent_id
		if flags.share_internal_state:
			internal_state = self.internal_states
		else:
			internal_state = self.internal_states[agent_id]
		value_batch, _ = self.estimate_value(agent_id=agent_id, states=[state], concats=[concat], internal_state=internal_state)
		record = self.batch.bootstrap
		for key, item in (
			('internal_state', internal_state),
			('agent_id', agent_id),
			('state', state),
			('concat', concat),
			('value', value_batch[0]),
		):
			record[key] = item
		
	def replay_value(self, batch): # replay values
		"""Re-estimate every stored value of ``batch`` with the current models,
		refresh the bootstrap value if there is one, then recompute the
		discounted rewards and return the batch."""
		for (agent_id,pos) in batch.step_generator():
			concat, state, internal_state = batch.get_action(['concats','states','internal_states'], agent_id, pos)
			fresh_values, _ = self.estimate_value(agent_id=agent_id, states=[state], concats=[concat], internal_state=internal_state)
			batch.set_action({'values':fresh_values[0]}, agent_id, pos)
		if 'value' in batch.bootstrap:
			record = batch.bootstrap
			fresh_values, _ = self.estimate_value(
				agent_id=record['agent_id'],
				states=[record['state']],
				concats=[record['concat']],
				internal_state=record['internal_state']
			)
			record['value'] = fresh_values[0]
		return self.compute_discounted_cumulative_reward(batch)
		
	def add_to_reward_prediction_buffer(self, batch):
		batch_size = batch.get_size(self.agents_set)
		if batch_size < 2:
			return
		batch_extrinsic_reward = batch.get_cumulative_reward(self.agents_set)[0]
		self.reward_prediction_buffer.put(batch=batch, type_id=1 if batch_extrinsic_reward != 0 else 0) # process batch only after sampling, for better perfomance
			
	def get_reward_prediction_tuple(self, batch):
		flat_states = [batch.get_action('states', agent_id, pos) for (agent_id,pos) in batch.step_generator(self.agents_set)]
		flat_rewards = [batch.get_action('rewards', agent_id, pos) for (agent_id,pos) in batch.step_generator(self.agents_set)]
		states_count = len(flat_states)
		length = min(3, states_count-1)
		start_idx = np.random.randint(states_count-length) if states_count > length else 0
		reward_prediction_states = [flat_states[start_idx+i] for i in range(length)]
		reward_prediction_target = np.zeros((1,3))
		
		target_reward = flat_rewards[start_idx+length][0] # use only extrinsic rewards
		if target_reward == 0:
			reward_prediction_target[0][0] = 1.0 # zero
		elif target_reward > 0:
			reward_prediction_target[0][1] = 1.0 # positive
		else:
			reward_prediction_target[0][2] = 1.0 # negative
		return reward_prediction_states, reward_prediction_target
			
	def add_to_replay_buffer(self, batch, batch_error):
		"""Store ``batch`` in the experience buffer.

		Empty batches are ignored, and so are zero-reward batches when
		``flags.save_only_batches_with_reward`` is set. ``batch_error`` is
		accepted but not used.
		"""
		if batch.get_size(self.agents_set) < 1:
			return
		batch_reward = batch.get_cumulative_reward(self.agents_set)
		extrinsic = batch_reward[0]
		intrinsic = batch_reward[1]
		total = extrinsic + intrinsic
		if total == 0 and flags.save_only_batches_with_reward:
			return
		if flags.replay_using_default_internal_state:
			batch.reset_internal_states()
		# 1: some intrinsic reward, 2: extrinsic reward only, 0: neither
		if intrinsic > 0:
			type_id = 1
		elif extrinsic > 0:
			type_id = 2
		else:
			type_id = 0
		if flags.prioritized_replay:
			self.experience_buffer.put(batch=batch, priority=total, type_id=type_id)
		else:
			self.experience_buffer.put(batch=batch, type_id=type_id)
		
	def replay_experience(self):
		"""Train on a Poisson-distributed number of batches sampled from the
		experience buffer, once it holds at least ``flags.replay_start`` items."""
		if not self.experience_buffer.has_atleast(flags.replay_start):
			return
		replay_count = np.random.poisson(flags.replay_ratio)
		for _ in range(replay_count):
			old_batch = self.experience_buffer.sample()
			if flags.replay_value:
				# refresh the stored values before training
				old_batch = self.replay_value(old_batch)
			self.train(old_batch)
		
	def process_batch(self, global_step):
		"""Train on the current batch, then replay and store experience."""
		batch = self.compute_discounted_cumulative_reward(self.batch)
		# reward prediction
		if flags.predict_reward:
			# do it before training, this way there will be at least one batch in the reward_prediction_buffer
			self.add_to_reward_prediction_buffer(batch)
			if self.reward_prediction_buffer.is_empty():
				# cannot train without reward prediction, wait until reward_prediction_buffer is not empty
				return
		# train
		batch_error = self.train(batch)
		# experience replay (after training!)
		replay_enabled = flags.replay_ratio > 0 and global_step > flags.replay_step
		if replay_enabled:
			self.replay_experience()
			self.add_to_replay_buffer(batch, batch_error)
Exemple #9
0
class Fdr(ContinualModel):
    """Continual model that replays stored network outputs from past tasks.

    At the end of each task the buffer is rebalanced so that every task seen
    so far keeps at most ``buffer_size // number_of_tasks`` examples, each
    stored together with the network output recorded at that moment.  During
    training, a second optimisation step pulls the softmax of the current
    outputs on buffered examples towards the softmax of the stored outputs.
    """
    NAME = 'fdr'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Fdr, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.current_task = 0  # number of tasks completed so far
        self.i = 0  # number of observe() calls
        self.soft = torch.nn.Softmax(dim=1)
        self.logsoft = torch.nn.LogSoftmax(dim=1)

    def end_task(self, dataset):
        """Rebalance the buffer and add examples of the task that just ended.

        :param dataset: object exposing ``train_loader``, an iterable of
            ``(inputs, labels, not_aug_inputs)`` batches.
        """
        self.current_task += 1
        examples_per_task = self.args.buffer_size // self.current_task

        if self.current_task > 1:
            # Shrink every previously stored task to the new per-task quota.
            buf_x, buf_log, buf_tl = self.buffer.get_all_data()
            self.buffer.empty()

            for ttl in buf_tl.unique():
                idx = (buf_tl == ttl)
                ex, log, tasklab = buf_x[idx], buf_log[idx], buf_tl[idx]
                first = min(ex.shape[0], examples_per_task)
                self.buffer.add_data(examples=ex[:first],
                                     logits=log[:first],
                                     task_labels=tasklab[:first])

        task_id = self.current_task - 1
        stored = 0
        with torch.no_grad():
            for inputs, _, not_aug_inputs in dataset.train_loader:
                remaining = examples_per_task - stored
                # Stop as soon as the quota is filled: no extra forward pass
                # and no zero-length add_data call.
                if remaining <= 0:
                    break
                inputs = inputs.to(self.device)
                not_aug_inputs = not_aug_inputs.to(self.device)
                # The forward pass still uses the full batch; only the stored
                # part is truncated.
                outputs = self.net(inputs)
                kept = not_aug_inputs[:remaining]
                self.buffer.add_data(
                    examples=kept,
                    logits=outputs.data[:remaining],
                    # One task label per stored row, even when the last batch
                    # is smaller than args.batch_size.
                    task_labels=torch.ones(kept.shape[0]) * task_id)
                stored += kept.shape[0]

    def observe(self, inputs, labels, not_aug_inputs):
        """Run one training step on the batch, then one replay step if possible.

        :return: the replay loss when the buffer is non-empty, otherwise the
            loss on the incoming batch (as a Python float).
        """
        self.i += 1

        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()
        if not self.buffer.is_empty():
            self.opt.zero_grad()
            buf_inputs, buf_logits, _ = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            buf_outputs = self.net(buf_inputs)
            # Mean L2 distance between current and stored softmax outputs.
            loss = torch.norm(
                self.soft(buf_outputs) - self.soft(buf_logits), 2, 1).mean()
            assert not torch.isnan(loss)
            loss.backward()
            self.opt.step()

        return loss.item()
Exemple #10
0
def _write_accuracy_line(text, header, accuracies):
    """Write ``header`` and then one line of percentages to the results file."""
    text.write(header + '\n')
    for a in accuracies:
        text.write(f"{a:.2f}% ")
    text.write('\n')


def _train_cl_epoch(model, loss, optimizer, train_loader, buffer, device,
                    replay_size, store_in_buffer):
    """Run one training epoch with replay and return (mean loss, mean accuracy)."""
    epoch_loss = []
    epoch_acc = []
    for x, y in train_loader:
        optimizer.zero_grad()

        # Move the batch once; the un-augmented copy is reused for the buffer.
        x = x.to(device)
        y = y.to(device)
        inputs, labels = x, y
        if not buffer.is_empty():
            # 50/50 strategy: a loader batch plus an equally sized replay batch.
            buf_input, buf_label = buffer.get_data(replay_size)
            inputs = torch.cat((inputs, torch.stack(buf_input)))
            labels = torch.cat((labels, torch.stack(buf_label)))

        y_pred = model(inputs)
        s_loss = loss(y_pred.squeeze(1), labels)
        acc = binary_accuracy(y_pred.squeeze(1), labels)
        # Per-epoch internal metrics.
        epoch_loss.append(s_loss.item())
        epoch_acc.append(acc.item())

        s_loss.backward()
        optimizer.step()

        if store_in_buffer:
            buffer.add_data(examples=x, labels=y)

    return statistics.mean(epoch_loss), statistics.mean(epoch_acc)


def train_cl(train_set, test_set, model, loss, optimizer, device, config):
    """
    Train ``model`` on a sequence of domains with a replay buffer and write
    the evaluation results and transfer metrics to a text file.

    :param train_set: Train set (one dataset per domain)
    :param test_set: Test set
    :param model: PyTorch model
    :param loss: loss function
    :param optimizer: optimizer
    :param device: device cuda/cpu
    :param config: configuration (reads 'buffer_size', 'batch_size', 'epochs')
    """

    name = ""
    global_writer = SummaryWriter('./runs/continual/train/global/' + name)
    buffer = Buffer(config['buffer_size'], device)
    accuracy = []
    n_epochs = config['epochs']
    last_epoch = n_epochs - 1

    try:
        # The context manager closes the results file even if training fails.
        with open("result_" + name + ".txt", "w") as text:
            # Eval without training
            random_accuracy = evaluate_past(model,
                                            len(test_set) - 1, test_set, loss,
                                            device)
            _write_accuracy_line(text, "Evaluation before training",
                                 random_accuracy)

            for index, data_set in enumerate(train_set):
                model.train()
                print(f"----- DOMAIN {index} -----")
                print("Training model...")
                train_loader = DataLoader(data_set,
                                          batch_size=config['batch_size'],
                                          shuffle=False)

                for epoch in tqdm(range(n_epochs)):
                    # The buffer is filled only during the first epoch of
                    # each domain.
                    mean_loss, mean_acc = _train_cl_epoch(
                        model, loss, optimizer, train_loader, buffer, device,
                        replay_size=config['batch_size'],
                        store_in_buffer=(epoch == 0))

                    step = epoch + (n_epochs * index)
                    global_writer.add_scalar('Train_global/Loss', mean_loss,
                                             step)
                    global_writer.add_scalar('Train_global/Accuracy',
                                             mean_acc, step)

                    # Report every 100 epochs and on the last epoch, whatever
                    # the configured number of epochs is.
                    if epoch % 100 == 0 or epoch == last_epoch:
                        print(
                            f'\nEpoch {epoch:03}/{n_epochs} | Loss: {mean_loss:.5f} '
                            f'| Acc: {mean_acc:.5f}')

                # Test on domain just trained + old domains
                evaluation = evaluate_past(model, index, test_set, loss,
                                           device)
                accuracy.append(evaluation)
                _write_accuracy_line(text, f"Evaluation after domain {index}",
                                     evaluation)

                if index != len(train_set) - 1:
                    accuracy[index].append(
                        evaluate_next(model, index, test_set, loss, device))

            # Check buffer distribution
            buffer.check_distribution()

            # Compute transfer metrics
            backward = backward_transfer(accuracy)
            forward = forward_transfer(accuracy, random_accuracy)
            forget = forgetting(accuracy)
            print(f'Backward transfer: {backward}')  # TODO: are these in %?
            print(f'Forward transfer: {forward}')
            print(f'Forgetting: {forget}')

            text.write(f"Backward: {backward}\n")
            text.write(f"Forward: {forward}\n")
            text.write(f"Forgetting: {forget}\n")
    finally:
        # Flush and release the event file in every case.
        global_writer.close()
Exemple #11
0
class OCILFAST(ContinualModel):
    NAME = 'OCILFAST'
    COMPATIBILITY = ['class-il', 'task-il']

    def __init__(self, net, loss, args, transform):
        """Set up the per-category containers, hyper-parameters and replay buffer."""
        super(OCILFAST, self).__init__(net, loss, args, transform)

        # Per-category state: one network and one center per category
        # (both filled in begin_task), plus a threshold list.
        self.nets, self.c, self.threshold = [], [], []

        # Hyper-parameters copied from the command-line arguments.
        self.nu = self.args.nu
        self.eta = self.args.eta
        self.eps = self.args.eps
        self.embedding_dim = self.args.embedding_dim
        self.weight_decay = self.args.weight_decay
        self.margin = self.args.margin

        # Task bookkeeping; cpt, nc and eye are resolved by the first begin_task.
        self.current_task = 0
        self.cpt = self.nc = self.eye = None

        self.buffer_size = self.args.buffer_size
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.nf = self.args.nf

        # Constant added to every network input: -0.5 for seq-cifar10 and
        # seq-mnist, 0 for every other dataset (seq-tinyimg included).
        shifted_datasets = ('seq-cifar10', 'seq-mnist')
        self.input_offset = -0.5 if self.args.dataset in shifted_datasets else 0

    # Task initialization
    def begin_task(self, dataset):
        """Prepare the model for a new task and advance the task counter.

        On the first call the class layout is read from ``dataset`` and one
        network plus one center is created for every class.
        """
        if self.cpt is None:
            self.cpt = dataset.N_CLASSES_PER_TASK
            self.nc = dataset.N_TASKS * self.cpt
            # Boolean mask: True on and below the diagonal, False above it.
            lower_triangle = torch.ones((self.nc, self.nc)).tril()
            self.eye = lower_triangle.bool().to(self.device)

        if not self.nets:
            for _ in range(self.nc):
                backbone = get_backbone(self.net, self.embedding_dim, self.nc,
                                        self.nf)
                self.nets.append(backbone.to(self.device))
                self.c.append(
                    torch.ones(self.embedding_dim, device=self.device))

        self.current_task += 1

    def train_model(self, dataset, train_loader):
        """Train the networks of the current task's categories, one at a time.

        For every category of the current task the per-category network is
        trained for ``args.n_epochs`` epochs.  On the first epoch and on every
        fifth epoch the progress is printed and a 2x2 diagnostic figure of the
        loss components is saved; a loss curve is saved per category.  The
        replay buffer is filled from ``train_loader`` at the end.

        :param dataset: unused by this method.
        :param train_loader: loader of the current task's training data.
        """
        first_category = (self.current_task - 1) * self.cpt
        categories = list(range(first_category, self.current_task * self.cpt))
        print('==========\t task: %d\t categories:' % self.current_task,
              categories, '\t==========')
        if self.args.print_file:
            print('==========\t task: %d\t categories:' % self.current_task,
                  categories,
                  '\t==========',
                  file=self.args.print_file)

        for category in categories:
            losses = []

            if category > 0:
                self.reset_train_loader(train_loader, category)

            for epoch in range(self.args.n_epochs):

                avg_loss, maxloss, posdist, negdist, gloloss = self.train_category(
                    train_loader, category, epoch)

                losses.append(avg_loss)
                if epoch != 0 and (epoch + 1) % 5 != 0:
                    continue

                progress = ("epoch: %d\t task: %d \t category: %d \t loss: %f" %
                            (epoch + 1, self.current_task, category, avg_loss))
                print(progress)
                if self.args.print_file:
                    print(progress, file=self.args.print_file)

                # (title, values, print the values when plotting fails)
                panels = (
                    ('maxloss', maxloss, False),
                    ('posdist', posdist, True),
                    ('negdist', negdist, True),
                    ('gloloss', gloloss, True),
                )
                figure = plt.figure(figsize=(20, 12))
                try:
                    for position, (title, values, echo) in enumerate(panels,
                                                                     start=1):
                        ax = plt.subplot(2, 2, position)
                        ax.set_title(title)
                        plt.xlim((0, 2))
                        if values is None:
                            # train_category returns None only for maxloss
                            # (no previous category to compare against).
                            continue
                        try:
                            sns.distplot(values)
                        except Exception:
                            # Plotting is best-effort: keep training, but let
                            # KeyboardInterrupt/SystemExit propagate.
                            if echo:
                                print(values)

                    plt.savefig("../" + self.args.img_dir +
                                "/loss-cat%d-epoch%d.png" % (category, epoch))
                finally:
                    # Close (not just clear) the figure so figures do not
                    # accumulate across epochs and categories.
                    plt.close(figure)

            # Same canvas size as the diagnostic figures.
            figure = plt.figure(figsize=(20, 12))
            try:
                plt.plot(list(range(len(losses))), losses)
                plt.savefig("../" + self.args.img_dir + "/loss-cat%d.png" %
                            (category))
            finally:
                plt.close(figure)

        self.fill_buffer(train_loader)

    def reset_train_loader(self, train_loader, category):
        """Attach distances to all previous categories' centers to the dataset.

        The dataset behind ``train_loader`` is traversed in order (no
        shuffling); for every sample the distances returned by ``predict`` for
        categories ``0 .. category - 1`` are collected and handed to
        ``dataset.set_prevdist`` as one CPU tensor with a row per sample.

        :param train_loader: loader whose ``dataset`` yields 3-tuples and
            implements ``set_prevdist``.
        :param category: category about to be trained.
        """
        dataset = train_loader.dataset
        loader = DataLoader(dataset,
                            batch_size=self.args.batch_size,
                            shuffle=False)

        prev_categories = list(range(category))
        print('prev_categories', prev_categories)
        if self.args.print_file:
            print('prev_categories',
                  prev_categories,
                  file=self.args.print_file)

        # Only the distances are needed; inputs and targets are not retained.
        prev_dists = [
            self.predict(batch_inputs, prev_categories)[1].detach().cpu()
            for batch_inputs, _, _ in loader
        ]
        dataset.set_prevdist(torch.cat(prev_dists, dim=0))

    def train_category(self, data_loader, category: int, epoch_id):
        """Train the network of one category for one pass over ``data_loader``.

        The center of ``category`` is re-initialised with ``init_center_c``,
        then ``self.nets[category]`` is optimised with SGD.  Samples labelled
        ``category`` are pulled inside radius ``args.r`` around the center
        (plus, for ``category > 0``, a margin term against the distances to
        previous categories); all other samples are pushed away through an
        inverse-distance loss weighted by ``eta``.

        :param data_loader: yields ``(inputs, semi_targets, prev_dists)``.
        :param category: index of the category/network to train.
        :param epoch_id: not used inside this method.
        :return: ``(avg_loss, maxloss, posdist, negdist, gloloss)`` where
            ``avg_loss`` is the sum of the per-batch mean losses divided by
            the number of processed samples, ``maxloss`` is ``None`` when no
            margin loss was recorded, and the remaining entries are flat numpy
            arrays of per-sample values.
        """

        self.init_center_c(data_loader, category)
        c = self.c[category]

        network = self.nets[category].to(self.device)
        network.train()

        optimizer = SGD(network.parameters(),
                        lr=self.args.lr,
                        weight_decay=self.weight_decay)
        avg_loss = 0.0
        sample_num = 0

        # Per-batch records, concatenated before returning.
        maxloss = []
        posdist = []
        negdist = []
        gloloss = []

        # NOTE(review): prev_categories is computed but never used below.
        prev_categories = list(range(category))
        for i, data in enumerate(data_loader):
            # NOTE(review): the third element is only meaningful as distances
            # to previous categories after reset_train_loader has run
            # (category > 0) -- confirm what the dataset yields for category 0.
            inputs, semi_targets, prev_dists = data
            inputs = inputs.to(self.device)
            semi_targets = semi_targets.to(self.device)
            prev_dists = prev_dists.to(self.device)

            if (not self.buffer.is_empty()) and self.args.buffer_size > 0:
                # Append replayed samples to the current batch.
                # NOTE(review): prev_dists is not extended for the replayed
                # samples, so `dists.view(-1, 1) - prev_dists` below looks
                # shape-mismatched when category > 0 and the buffer is
                # non-empty -- confirm.
                buf_inputs, buf_labels = self.buffer.get_data(
                    self.args.minibatch_size, transform=self.transform)
                # print(buf_inputs[0])
                inputs = torch.cat((inputs, buf_inputs))
                semi_targets = torch.cat((semi_targets, buf_labels))

            # Zero the network parameter gradients
            optimizer.zero_grad()

            # The input is shifted by input_offset (-0.5 for seq-cifar10 and
            # seq-mnist, 0 otherwise).
            outputs = network(inputs + self.input_offset)

            # Squared distance of every embedding to the center.
            dists = torch.sum((outputs - c)**2, dim=1)
            pos_dist_loss = torch.relu(dists - self.args.r)

            if category > 0:
                # Penalise samples that are farther from this center than from
                # a previous category's center, averaged over those categories.
                max_scores = torch.relu(dists.view(-1, 1) - prev_dists)
                max_loss = torch.sum(max_scores,
                                     dim=1) * self.margin / category

                loss_pos = pos_dist_loss + max_loss
                loss_neg = self.eta * dists**-1

                pos_max_loss = max_loss[semi_targets == category]
                maxloss.append(pos_max_loss.detach().cpu().data.numpy())

            else:
                loss_pos = pos_dist_loss
                loss_neg = self.eta * dists**-1

            # Positive loss for samples of this category, negative loss for the rest.
            losses = torch.where(semi_targets == category, loss_pos, loss_neg)
            gloloss.append(losses.detach().cpu().data.numpy())
            loss = torch.mean(losses)

            loss.backward()
            optimizer.step()

            # Record the individual loss components
            pos_dist = pos_dist_loss[semi_targets == category]
            posdist.append(pos_dist.detach().cpu().data.numpy())

            neg_dist = loss_neg[semi_targets != category]
            negdist.append(neg_dist.detach().cpu().data.numpy())

            avg_loss += loss.item()
            sample_num += inputs.shape[0]

            # Categories of earlier tasks are trained on a single batch only
            if category < (self.current_task - 1) * self.cpt:
                break

        avg_loss /= sample_num
        if len(maxloss) > 0:
            maxloss = np.hstack(maxloss)
        else:
            maxloss = None
        posdist = np.hstack(posdist)
        negdist = np.hstack(negdist)
        gloloss = np.hstack(gloloss)
        return avg_loss, maxloss, posdist, negdist, gloloss

    def fill_buffer(self, train_loader):
        """Add every batch of ``train_loader`` to the replay buffer.

        Each batch is a 3-tuple; the third element is stored as the examples
        and the second as the labels, while the first is ignored.
        """
        for batch in train_loader:
            _, semi_targets, not_aug_inputs = batch
            self.buffer.add_data(examples=not_aug_inputs, labels=semi_targets)

    def init_center_c(self, train_loader: DataLoader, category):
        """Initialize hypersphere center c as the mean from an initial forward pass on the data.

        Only the samples labelled ``category`` contribute to the mean.  The
        result is stored in ``self.c[category]``.

        :param train_loader: iterable of ``(inputs, semi_targets, extra)``
            batches; the third element is ignored.
        :param category: index of the category whose center is initialised.
        :raises ValueError: if ``train_loader`` yields no sample labelled
            ``category`` (the mean would otherwise be NaN or a division by
            zero).
        """
        net = self.nets[category].to(self.device)
        net.eval()

        n_samples = 0
        total = 0
        with torch.no_grad():
            for inputs, semi_targets, _ in train_loader:
                inputs = inputs.to(self.device)
                semi_targets = semi_targets.to(self.device)
                outputs = net(inputs + self.input_offset)
                # Keep only the positive samples of this category.
                positives = outputs[semi_targets == category]
                n_samples += positives.shape[0]
                total += torch.sum(positives, dim=0)

        if n_samples == 0:
            raise ValueError(
                "cannot initialise the center of category %d: "
                "no sample with that label in the loader" % category)

        c = total / n_samples

        # If c_i is too close to 0, set to +-eps. Reason: a zero unit can be
        # trivially matched with zero weights. Components that are exactly 0
        # are left unchanged, as neither sign test matches them.
        near_zero = torch.abs(c) < self.eps
        c[near_zero & (c < 0)] = -self.eps
        c[near_zero & (c > 0)] = self.eps
        self.c[category] = c.to(self.device)

    def get_score(self, dist, category):
        """Turn a distance into a score: the smaller the distance, the higher the score.

        ``category`` is accepted but does not influence the result.
        """
        smoothing = 1e-6  # keeps the division finite when dist is 0
        return 1 / (dist + smoothing)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score ``x`` against every category of the tasks begun so far.

        Returns the first element of ``predict``'s result (the scores).
        """
        n_seen = self.current_task * self.cpt
        prediction = self.predict(x, list(range(n_seen)))
        return prediction[0]

    def predict(self, inputs: torch.Tensor, categories):
        """Evaluate ``inputs`` with the network of every requested category.

        :param inputs: batch of inputs; moved to ``self.device``.
        :param categories: iterable of category indices to evaluate.
        :return: ``(outcome, dists)``, each with one column per requested
            category: the scores from ``get_score`` and the squared distances
            to the category centers.
        """
        inputs = inputs.to(self.device)
        score_columns = []
        dist_columns = []
        with torch.no_grad():
            for category in categories:
                net = self.nets[category]
                net.to(self.device)
                net.eval()

                center = self.c[category].to(self.device)

                embedding = net(inputs + self.input_offset)
                sq_dist = ((embedding - center)**2).sum(dim=1)

                score = self.get_score(sq_dist, category)
                score_columns.append(score.view(-1, 1))
                dist_columns.append(sq_dist.view(-1, 1))

        return torch.cat(score_columns, dim=1), torch.cat(dist_columns, dim=1)