class Er(ContinualModel):
    """Experience Replay: interleave the incoming batch with samples drawn
    from a rehearsal buffer and optimize on the union."""
    NAME = 'er'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Er, self).__init__(backbone, loss, args, transform)
        # Rehearsal memory shared across all tasks.
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def observe(self, inputs, labels, not_aug_inputs):
        """Run one training step on `inputs`, mixed with replayed samples
        when the buffer is non-empty, then store the raw (non-augmented)
        examples. Returns the scalar loss value."""
        stream_size = inputs.shape[0]
        self.opt.zero_grad()

        # Extend the batch with replayed examples once the buffer has data.
        if not self.buffer.is_empty():
            past_x, past_y = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            inputs = torch.cat((inputs, past_x))
            labels = torch.cat((labels, past_y))

        predictions = self.net(inputs)
        batch_loss = self.loss(predictions, labels)
        batch_loss.backward()
        self.opt.step()

        # Only the leading `stream_size` labels belong to the raw examples.
        self.buffer.add_data(examples=not_aug_inputs,
                             labels=labels[:stream_size])

        return batch_loss.item()
class Der(ContinualModel):
    """Dark Experience Replay: adds an alpha-weighted MSE term between the
    network's current outputs on buffered inputs and the logits that were
    stored with them."""
    NAME = 'der'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Der, self).__init__(backbone, loss, args, transform)
        # Buffer stores raw examples together with the logits produced for them.
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def observe(self, inputs, labels, not_aug_inputs):
        """One optimization step: classification loss on the stream batch
        plus a logit-matching penalty on replayed samples. Returns the
        scalar loss value."""
        self.opt.zero_grad()

        stream_logits = self.net(inputs)
        total_loss = self.loss(stream_logits, labels)

        if not self.buffer.is_empty():
            past_x, past_logits = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            replay_logits = self.net(past_x)
            # Distillation-style penalty against the stored logits.
            total_loss += self.args.alpha * F.mse_loss(replay_logits, past_logits)

        total_loss.backward()
        self.opt.step()

        # Store the raw inputs with the logits just computed for them.
        self.buffer.add_data(examples=not_aug_inputs, logits=stream_logits.data)

        return total_loss.item()
class Mer(ContinualModel):
    """Meta-Experience Replay (MER): Reptile-style meta-updates within and
    across a sequence of small replay batches."""
    NAME = 'mer'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Mer, self).__init__(backbone, loss, args, transform)
        # Rehearsal memory shared across tasks.
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def draw_batches(self, inp, lab):
        """Build `args.batch_num` batches, each of buffer samples plus the
        current example appended last.

        NOTE(review): `inp.unsqueeze(0)` treats `inp` as a single example,
        i.e. this assumes an effective stream batch size of 1 — confirm
        against the data loader feeding observe().
        """
        batches = []
        for i in range(self.args.batch_num):
            if not self.buffer.is_empty():
                buf_inputs, buf_labels = self.buffer.get_data(
                    self.args.minibatch_size, transform=self.transform)
                inputs = torch.cat((buf_inputs, inp.unsqueeze(0)))
                labels = torch.cat(
                    (buf_labels, torch.tensor([lab]).to(self.device)))
                batches.append((inputs, labels))
            else:
                # Empty buffer: the batch is just the current example.
                batches.append(
                    (inp.unsqueeze(0),
                     torch.tensor([lab]).unsqueeze(0).to(self.device)))
        return batches

    def observe(self, inputs, labels, not_aug_inputs):
        """One MER step: SGD within each drawn batch, a Reptile interpolation
        (beta) after each batch, and a final across-batch interpolation
        (gamma) toward the pre-step weights. Returns the last batch loss.

        NOTE(review): if `args.batch_num == 0` the loop never runs and
        `loss` below is unbound (NameError) — presumably batch_num >= 1.
        """
        batches = self.draw_batches(inputs, labels)
        # Snapshot of the weights before any within-step updates.
        theta_A0 = self.net.get_params().data.clone()

        for i in range(self.args.batch_num):
            # Snapshot before this batch's inner SGD step.
            theta_Wi0 = self.net.get_params().data.clone()
            batch_inputs, batch_labels = batches[i]

            # within-batch step
            self.opt.zero_grad()
            outputs = self.net(batch_inputs)
            loss = self.loss(outputs, batch_labels.squeeze(-1))
            loss.backward()
            self.opt.step()

            # within batch reptile meta-update:
            # move only a beta-fraction from the snapshot toward the new weights.
            new_params = theta_Wi0 + self.args.beta * (self.net.get_params() - theta_Wi0)
            self.net.set_params(new_params)

        self.buffer.add_data(examples=not_aug_inputs.unsqueeze(0), labels=labels)

        # across batch reptile meta-update (gamma-fraction from theta_A0).
        new_new_params = theta_A0 + self.args.gamma * (self.net.get_params() - theta_A0)
        self.net.set_params(new_new_params)

        return loss.item()
class AGem(ContinualModel):
    """Averaged GEM: projects the current-task gradient whenever it
    conflicts (negative dot product) with the gradient computed on the
    episodic memory."""
    NAME = 'agem'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(AGem, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        # Per-parameter element counts, used to (de)serialize flat gradients.
        self.grad_dims = []
        for param in self.parameters():
            self.grad_dims.append(param.data.numel())
        # Flat scratch vectors: current-task gradient and memory gradient.
        self.grad_xy = torch.Tensor(np.sum(self.grad_dims)).to(self.device)
        self.grad_er = torch.Tensor(np.sum(self.grad_dims)).to(self.device)
        # Replay-time augmentation only when args.iba is set.
        self.transform = transform if self.args.iba else None

    def end_task(self, dataset):
        """Store a fixed quota of this task's samples in the buffer."""
        samples_per_task = self.args.buffer_size // dataset.N_TASKS
        loader = dataset.not_aug_dataloader(self.args, samples_per_task)
        cur_x, cur_y = next(iter(loader))[:2]
        self.buffer.add_data(examples=cur_x.to(self.device),
                             labels=cur_y.to(self.device))

    def observe(self, inputs, labels, not_aug_inputs):
        """One A-GEM step; returns the (unconstrained) current-batch loss."""
        self.zero_grad()
        p = self.net.forward(inputs)
        loss = self.loss(p, labels)
        loss.backward()
        if not self.buffer.is_empty():
            # Save the current-task gradient before overwriting it below.
            store_grad(self.parameters, self.grad_xy, self.grad_dims)
            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            self.net.zero_grad()
            buf_outputs = self.net.forward(buf_inputs)
            penalty = self.loss(buf_outputs, buf_labels)
            penalty.backward()
            # Gradient on the episodic memory.
            store_grad(self.parameters, self.grad_er, self.grad_dims)
            dot_prod = torch.dot(self.grad_xy, self.grad_er)
            if dot_prod.item() < 0:
                # Conflict: project the task gradient onto the memory gradient's
                # non-violating half-space, then write it back to .grad.
                g_tilde = project(gxy=self.grad_xy, ger=self.grad_er)
                overwrite_grad(self.parameters, g_tilde, self.grad_dims)
            else:
                # No conflict: restore the original task gradient.
                overwrite_grad(self.parameters, self.grad_xy, self.grad_dims)
        self.opt.step()
        return loss.item()
class AGemr(ContinualModel):
    """A-GEM with a reservoir buffer: same gradient-projection rule as AGem,
    but samples are added every step in observe() instead of at task end,
    which also enables the general-continual setting."""
    NAME = 'agem_r'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(AGemr, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        # Per-parameter element counts for flat-gradient (de)serialization.
        self.grad_dims = []
        for param in self.parameters():
            self.grad_dims.append(param.data.numel())
        # Flat scratch vectors: current-task gradient and memory gradient.
        self.grad_xy = torch.Tensor(np.sum(self.grad_dims)).to(self.device)
        self.grad_er = torch.Tensor(np.sum(self.grad_dims)).to(self.device)
        self.current_task = 0

    def observe(self, inputs, labels, not_aug_inputs):
        """One A-GEM step plus reservoir insertion; returns the batch loss."""
        self.zero_grad()
        p = self.net.forward(inputs)
        loss = self.loss(p, labels)
        loss.backward()
        if not self.buffer.is_empty():
            # Save the current-task gradient before it is overwritten below.
            store_grad(self.parameters, self.grad_xy, self.grad_dims)
            # NOTE(review): unlike AGem.observe, no transform is applied here.
            buf_inputs, buf_labels = self.buffer.get_data(self.args.minibatch_size)
            self.net.zero_grad()
            buf_outputs = self.net.forward(buf_inputs)
            penalty = self.loss(buf_outputs, buf_labels)
            penalty.backward()
            store_grad(self.parameters, self.grad_er, self.grad_dims)
            dot_prod = torch.dot(self.grad_xy, self.grad_er)
            if dot_prod.item() < 0:
                # Conflict with memory gradient: project and write back.
                g_tilde = project(gxy=self.grad_xy, ger=self.grad_er)
                overwrite_grad(self.parameters, g_tilde, self.grad_dims)
            else:
                # No conflict: restore the original task gradient.
                overwrite_grad(self.parameters, self.grad_xy, self.grad_dims)
        self.opt.step()
        # Reservoir-style insertion of the raw examples every step.
        self.buffer.add_data(examples=not_aug_inputs, labels=labels)
        return loss.item()
class Gem(ContinualModel):
    """Gradient Episodic Memory: keeps one reference gradient per past task
    and projects the current gradient into the cone that does not increase
    any past-task loss (QP solved by project2cone2)."""
    NAME = 'gem'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(Gem, self).__init__(backbone, loss, args, transform)
        self.current_task = 0
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.transform = transform
        # Allocate temporary synaptic memory
        self.grad_dims = []
        for pp in self.parameters():
            self.grad_dims.append(pp.data.numel())
        # grads_cs: one flat reference gradient per completed task;
        # grads_da: flat scratch vector for the current-data gradient.
        self.grads_cs = []
        self.grads_da = torch.zeros(np.sum(self.grad_dims)).to(self.device)
        # NOTE(review): overwrites the unconditional assignment above —
        # the first `self.transform = transform` is dead code.
        self.transform = transform if self.args.iba else None

    def end_task(self, dataset):
        """Close a task: allocate its reference-gradient slot and store a
        quota of its samples (with task labels) in the buffer."""
        self.current_task += 1
        self.grads_cs.append(
            torch.zeros(np.sum(self.grad_dims)).to(self.device))

        # add data to the buffer
        samples_per_task = self.args.buffer_size // dataset.N_TASKS

        loader = dataset.not_aug_dataloader(self.args, samples_per_task)
        cur_x, cur_y = next(iter(loader))[:2]
        self.buffer.add_data(
            examples=cur_x.to(self.device),
            labels=cur_y.to(self.device),
            task_labels=torch.ones(samples_per_task, dtype=torch.long).to(
                self.device) * (self.current_task - 1))

    def observe(self, inputs, labels, not_aug_inputs):
        """One GEM step; returns the current-batch loss."""
        if not self.buffer.is_empty():
            buf_inputs, buf_labels, buf_task_labels = self.buffer.get_data(
                self.args.buffer_size, transform=self.transform)
            for tt in buf_task_labels.unique():
                # compute gradient on the memory buffer
                self.opt.zero_grad()
                cur_task_inputs = buf_inputs[buf_task_labels == tt]
                cur_task_labels = buf_labels[buf_task_labels == tt]
                # Chunked forward/backward to bound peak memory; the sum
                # reduction divided by the task count averages the loss.
                for i in range(
                        math.ceil(len(cur_task_inputs) / self.args.batch_size)):
                    cur_task_outputs = self.forward(
                        cur_task_inputs[i * self.args.batch_size:(i + 1) *
                                        self.args.batch_size])
                    penalty = self.loss(
                        cur_task_outputs,
                        cur_task_labels[i * self.args.batch_size:(i + 1) *
                                        self.args.batch_size],
                        reduction='sum') / cur_task_inputs.shape[0]
                    penalty.backward()
                # Accumulated grads -> flat reference gradient for task tt.
                store_grad(self.parameters, self.grads_cs[tt], self.grad_dims)
                # cur_task_outputs = self.forward(cur_task_inputs)
                # penalty = self.loss(cur_task_outputs, cur_task_labels)
                # penalty.backward()
                # store_grad(self.parameters, self.grads_cs[tt], self.grad_dims)

        # now compute the grad on the current data
        self.opt.zero_grad()
        outputs = self.forward(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()

        # check if gradient violates buffer constraints
        if not self.buffer.is_empty():
            # copy gradient
            store_grad(self.parameters, self.grads_da, self.grad_dims)
            dot_prod = torch.mm(self.grads_da.unsqueeze(0),
                                torch.stack(self.grads_cs).T)
            if (dot_prod < 0).sum() != 0:
                # project2cone2 modifies grads_da in place (margin=gamma).
                project2cone2(self.grads_da.unsqueeze(1),
                              torch.stack(self.grads_cs).T,
                              margin=self.args.gamma)
                # copy gradients back
                overwrite_grad(self.parameters, self.grads_da, self.grad_dims)

        self.opt.step()
        return loss.item()
class HAL(ContinualModel):
    """Hindsight Anchor Learning: besides replay, learns per-class "anchor"
    inputs and penalizes prediction drift on them after each update."""
    NAME = 'hal'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(HAL, self).__init__(backbone, loss, args, transform)
        self.task_number = 0
        # Ring buffer partitioned across tasks.
        self.buffer = Buffer(self.args.buffer_size, self.device,
                             get_dataset(args).N_TASKS, mode='ring')
        self.hal_lambda = args.hal_lambda    # weight of the anchor-drift penalty
        self.beta = args.beta                # EMA factor for the feature mean phi
        self.gamma = args.gamma              # weight of the feature-matching term
        self.anchor_optimization_steps = 100
        self.finetuning_epochs = 1
        self.dataset = get_dataset(args)
        # Auxiliary copy of the backbone used only for anchor optimization.
        self.spare_model = self.dataset.get_backbone()
        self.spare_model.to(self.device)
        self.spare_opt = SGD(self.spare_model.parameters(), lr=self.args.lr)

    def end_task(self, dataset):
        """Advance the ring buffer and learn anchors for the finished task."""
        self.task_number += 1
        # Ring buffer management (skipped when resuming from a checkpoint).
        if self.task_number > self.buffer.task_number:
            self.buffer.num_seen_examples = 0
            self.buffer.task_number = self.task_number
        # Get anchors (provided that we are not loading the model).
        if len(self.anchors) < self.task_number * dataset.N_CLASSES_PER_TASK:
            self.get_anchors(dataset)
            # phi is rebuilt lazily for the next task in observe().
            del self.phi

    def get_anchors(self, dataset):
        """Optimize one anchor input per class of the current task:
        maximize loss under fine-tuned weights, minimize it under current
        weights, and match the running feature mean phi."""
        theta_t = self.net.get_params().detach().clone()
        self.spare_model.set_params(theta_t)

        # fine tune on memory buffer
        for _ in range(self.finetuning_epochs):
            inputs, labels = self.buffer.get_data(self.args.batch_size,
                                                  transform=self.transform)
            self.spare_opt.zero_grad()
            out = self.spare_model(inputs)
            loss = self.loss(out, labels)
            loss.backward()
            self.spare_opt.step()

        theta_m = self.spare_model.get_params().detach().clone()

        classes_for_this_task = np.unique(dataset.train_loader.dataset.targets)

        for a_class in classes_for_this_task:
            e_t = torch.rand(self.input_shape, requires_grad=True,
                             device=self.device)
            e_t_opt = SGD([e_t], lr=self.args.lr)
            print(file=sys.stderr)
            for i in range(self.anchor_optimization_steps):
                e_t_opt.zero_grad()
                cum_loss = 0

                # Term 1: maximize loss under the fine-tuned weights theta_m.
                self.spare_opt.zero_grad()
                self.spare_model.set_params(theta_m.detach().clone())
                loss = -torch.sum(
                    self.loss(self.spare_model(e_t.unsqueeze(0)),
                              torch.tensor([a_class]).to(self.device)))
                loss.backward()
                cum_loss += loss.item()

                # Term 2: minimize loss under the pre-fine-tuning weights theta_t.
                self.spare_opt.zero_grad()
                self.spare_model.set_params(theta_t.detach().clone())
                loss = torch.sum(
                    self.loss(self.spare_model(e_t.unsqueeze(0)),
                              torch.tensor([a_class]).to(self.device)))
                loss.backward()
                cum_loss += loss.item()

                # Term 3: keep anchor features close to the running mean phi.
                self.spare_opt.zero_grad()
                loss = torch.sum(self.gamma * (self.spare_model.features(
                    e_t.unsqueeze(0)) - self.phi)**2)
                assert not self.phi.requires_grad
                loss.backward()
                cum_loss += loss.item()

                e_t_opt.step()

            e_t = e_t.detach()
            e_t.requires_grad = False
            self.anchors = torch.cat((self.anchors, e_t.unsqueeze(0)))
            del e_t
        print('Total anchors:', len(self.anchors), file=sys.stderr)
        self.spare_model.zero_grad()

    def observe(self, inputs, labels, not_aug_inputs):
        """ER-style step plus a second update penalizing output drift on the
        anchors; also maintains the EMA feature mean phi."""
        real_batch_size = inputs.shape[0]
        # Lazily initialized state (depends on the first batch's shape).
        if not hasattr(self, 'input_shape'):
            self.input_shape = inputs.shape[1:]
        if not hasattr(self, 'anchors'):
            self.anchors = torch.zeros(tuple([0] + list(self.input_shape))).to(
                self.device)
        if not hasattr(self, 'phi'):
            print('Building phi', file=sys.stderr)
            with torch.no_grad():
                self.phi = torch.zeros_like(self.net.features(
                    inputs[0].unsqueeze(0)), requires_grad=False)
            assert not self.phi.requires_grad

        if not self.buffer.is_empty():
            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            inputs = torch.cat((inputs, buf_inputs))
            labels = torch.cat((labels, buf_labels))

        # Snapshot to measure drift on the anchors after the step below.
        old_weights = self.net.get_params().detach().clone()

        self.opt.zero_grad()
        outputs = self.net(inputs)

        k = self.task_number

        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()
        first_loss = 0

        assert len(self.anchors) == self.dataset.N_CLASSES_PER_TASK * k
        if len(self.anchors) > 0:
            first_loss = loss.item()
            with torch.no_grad():
                pred_anchors = self.net(self.anchors)

            # Drift = new anchor predictions minus pre-update predictions.
            self.net.set_params(old_weights)
            pred_anchors -= self.net(self.anchors)
            loss = self.hal_lambda * (pred_anchors**2).mean()
            loss.backward()
            self.opt.step()

        with torch.no_grad():
            # EMA of the stream-batch feature mean (replay part excluded).
            self.phi = self.beta * self.phi + (1 - self.beta) * \
                self.net.features(inputs[:real_batch_size]).mean(0)

        self.buffer.add_data(examples=not_aug_inputs,
                             labels=labels[:real_batch_size])

        return first_loss + loss.item()
class BasicManager(object):
    """Manages one or more A3C-style agent networks (TF1): construction,
    sync with a global network, training, experience replay, reward
    prediction, and count-based exploration bonuses.

    NOTE(review): this block was reconstructed from whitespace-mangled
    source; statement nesting inside flag-guarded branches follows the
    most plausible reading — verify against the original repository.
    """

    def __init__(self, session, device, id, action_shape, state_shape,
                 concat_size=0, global_network=None, training=True):
        self.training = training
        self.session = session
        self.id = id
        self.device = device
        self.state_shape = state_shape
        self.set_model_size()
        if self.training:
            self.global_network = global_network
            # Gradient optimizer and clip range
            if not self.is_global_network():
                # Workers reuse the clip schedule owned by the global network.
                self.clip = self.global_network.clip
            else:
                self.initialize_gradient_optimizer()
            # Build agents
            self.model_list = []
            self.build_agents(state_shape=state_shape,
                              action_shape=action_shape,
                              concat_size=concat_size)
            # Build experience buffer
            if flags.replay_ratio > 0:
                if flags.prioritized_replay:
                    self.experience_buffer = PrioritizedBuffer(
                        size=flags.replay_buffer_size)
                    # self.beta_schedule = LinearSchedule(flags.max_time_step, initial_p=0.4, final_p=1.0)
                else:
                    self.experience_buffer = Buffer(size=flags.replay_buffer_size)
            if flags.predict_reward:
                self.reward_prediction_buffer = Buffer(
                    size=flags.reward_prediction_buffer_size)
            # Bind optimizer to global
            if not self.is_global_network():
                self.bind_to_global(self.global_network)
            # Count based exploration
            if flags.use_count_based_exploration_reward:
                self.projection = None
                self.projection_dataset = []
            if flags.print_loss:
                self._loss_list = [{} for _ in range(self.model_size)]
        else:
            # Evaluation mode: share models with the given global network.
            self.global_network = None
            self.model_list = global_network.model_list
        # Statistics
        self._model_usage_list = deque()

    def is_global_network(self):
        # The global network is the one built without a parent network.
        return self.global_network is None

    def set_model_size(self):
        # Basic manager drives a single agent (id 0); subclasses may override.
        self.model_size = 1
        self.agents_set = set([0])

    def build_agents(self, state_shape, action_shape, concat_size):
        # Network class is resolved by name from flags.network via eval.
        agent = eval('{}_Network'.format(flags.network))(
            session=self.session,
            id='{0}_{1}'.format(self.id, 0),
            device=self.device,
            state_shape=state_shape,
            action_shape=action_shape,
            concat_size=concat_size,
            clip=self.clip[0],
            predict_reward=flags.predict_reward,
            training=self.training
        )
        self.model_list.append(agent)

    def sync(self):
        """Copy global weights into every local agent."""
        # assert not self.is_global_network(), 'you are trying to sync the global network with itself'
        for i in range(self.model_size):
            agent = self.model_list[i]
            sync = self.sync_list[i]
            agent.sync(sync)

    def initialize_gradient_optimizer(self):
        """Create per-agent global step, (optionally annealed) learning rate
        and clip schedules, and the optimizer — all resolved from flags."""
        self.global_step = []
        self.learning_rate = []
        self.clip = []
        self.gradient_optimizer = []
        for i in range(self.model_size):
            # global step
            self.global_step.append(tf.Variable(0, trainable=False))
            # learning rate
            self.learning_rate.append(
                eval('tf.train.'+flags.alpha_annealing_function)(
                    learning_rate=flags.alpha,
                    global_step=self.global_step[i],
                    decay_steps=flags.alpha_decay_steps,
                    decay_rate=flags.alpha_decay_rate)
                if flags.alpha_decay else flags.alpha)
            # clip
            self.clip.append(
                eval('tf.train.'+flags.clip_annealing_function)(
                    learning_rate=flags.clip,
                    global_step=self.global_step[i],
                    decay_steps=flags.clip_decay_steps,
                    decay_rate=flags.clip_decay_rate)
                if flags.clip_decay else flags.clip)
            # gradient optimizer
            self.gradient_optimizer.append(
                eval('tf.train.'+flags.optimizer+'Optimizer')(
                    learning_rate=self.learning_rate[i], use_locking=True))

    def bind_to_global(self, global_network):
        """Wire each local agent's loss to the global optimizer and build
        the weight-sync ops."""
        self.sync_list = []
        for i in range(self.model_size):
            local_agent = self.get_model(i)
            global_agent = global_network.get_model(i)
            local_agent.minimize_local_loss(
                optimizer=global_network.gradient_optimizer[i],
                global_step=global_network.global_step[i],
                global_var_list=global_agent.get_shared_keys())
            # for syncing local network with global one
            self.sync_list.append(local_agent.bind_sync(global_agent))

    def get_model(self, id):
        return self.model_list[id]

    def get_statistics(self):
        """Return averaged per-loss statistics and model-usage ratios."""
        stats = {}
        if self.training:
            # build loss statistics
            if flags.print_loss:
                for i in range(self.model_size):
                    for key, value in self._loss_list[i].items():
                        stats['loss_{}{}_avg'.format(key, i)] = np.average(value)
            # build models usage statistics
            if self.model_size > 1:
                total_usage = 0
                usage_matrix = {}
                for u in self._model_usage_list:
                    if not (u in usage_matrix):
                        usage_matrix[u] = 0
                    usage_matrix[u] += 1
                    total_usage += 1
                for i in range(self.model_size):
                    stats['model_{}'.format(i)] = 0
                for key, value in usage_matrix.items():
                    stats['model_{}'.format(key)] = value/total_usage if total_usage != 0 else 0
        return stats

    def add_to_statistics(self, id):
        self._model_usage_list.append(id)
        if len(self._model_usage_list) > flags.match_count_for_evaluation:
            self._model_usage_list.popleft()  # remove old statistics

    def get_shared_keys(self):
        # Concatenation of every agent's shared variables.
        vars = []
        for agent in self.model_list:
            vars += agent.get_shared_keys()
        return vars

    def reset(self):
        """Reset per-episode state (step counter, RNN states, hash table)."""
        self.step = 0
        self.agent_id = 0
        # Internal states
        self.internal_states = None if flags.share_internal_state else [None]*self.model_size
        if self.training:
            # Count based exploration
            if flags.use_count_based_exploration_reward:
                self.hash_state_table = {}

    def initialize_new_batch(self):
        self.batch = ExperienceBatch(self.model_size)

    def estimate_value(self, agent_id, states, concats=None, internal_state=None):
        return self.get_model(agent_id).predict_value(
            states=states, concats=concats, internal_state=internal_state)

    def act(self, act_function, state, concat=None):
        """Predict an action for `state`, apply it through `act_function`,
        and (in training) record the transition in the current batch."""
        agent_id = self.agent_id
        agent = self.get_model(agent_id)
        internal_state = self.internal_states if flags.share_internal_state else self.internal_states[agent_id]
        action_batch, value_batch, policy_batch, new_internal_state = agent.predict_action(
            states=[state], concats=[concat], internal_state=internal_state)
        if flags.share_internal_state:
            self.internal_states = new_internal_state
        else:
            self.internal_states[agent_id] = new_internal_state
        action, value, policy = action_batch[0], value_batch[0], policy_batch[0]
        new_state, extrinsic_reward, terminal = act_function(action)
        if self.training:
            if flags.clip_reward:
                extrinsic_reward = np.clip(extrinsic_reward, flags.min_reward, flags.max_reward)
        intrinsic_reward = 0
        if self.training:
            if flags.use_count_based_exploration_reward:
                # intrinsic reward
                intrinsic_reward += self.get_count_based_exploration_reward(new_state)
        # Rewards are kept as an (extrinsic, intrinsic) pair.
        total_reward = np.array([extrinsic_reward, intrinsic_reward], dtype=np.float32)
        if self.training:
            self.batch.add_action(agent_id=agent_id, state=state, concat=concat,
                                  action=action, policy=policy,
                                  reward=total_reward, value=value,
                                  internal_state=internal_state)
        # update step at the end of the action
        self.step += 1
        # return result
        return new_state, value, action, total_reward, terminal, policy

    def get_count_based_exploration_reward(self, new_state):
        """Exploration bonus from visit counts of locality-sensitive hashes
        of (randomly projected) states; 0 until the projection is fitted."""
        if len(self.projection_dataset) < flags.projection_dataset_size:
            self.projection_dataset.append(new_state.flatten())
        if len(self.projection_dataset) == flags.projection_dataset_size:
            if self.projection is None:
                # http://scikit-learn.org/stable/modules/random_projection.html
                self.projection = SparseRandomProjection(
                    n_components=flags.exploration_hash_size if flags.exploration_hash_size > 0 else 'auto')
            self.projection.fit(self.projection_dataset)
            self.projection_dataset = []  # reset
        if self.projection is not None:
            state_projection = self.projection.transform([new_state.flatten()])[0]  # project to smaller dimension
            state_hash = ''.join('1' if x > 0 else '0' for x in state_projection)  # build binary locality-sensitive hash
            if state_hash not in self.hash_state_table:
                self.hash_state_table[state_hash] = 1
            else:
                self.hash_state_table[state_hash] += 1
            exploration_bonus = 2/np.sqrt(self.hash_state_table[state_hash]) - 1  # in [-1,1]
            return flags.positive_exploration_coefficient*exploration_bonus if exploration_bonus > 0 else flags.negative_exploration_coefficient*exploration_bonus
        return 0

    def compute_discounted_cumulative_reward(self, batch):
        # Bootstrap with the stored value estimate when available, else 0.
        last_value = batch.bootstrap['value'] if 'value' in batch.bootstrap else 0.
        batch.compute_discounted_cumulative_reward(
            agents=self.agents_set, last_value=last_value,
            gamma=flags.gamma, lambd=flags.lambd)
        return batch

    def train(self, batch):
        """Train every agent on its slice of `batch`; returns per-agent errors."""
        # assert self.global_network is not None, 'Cannot train the global network.'
        states = batch.states
        internal_states = batch.internal_states
        concats = batch.concats
        actions = batch.actions
        policies = batch.policies
        values = batch.values
        rewards = batch.rewards
        dcr = batch.discounted_cumulative_rewards
        gae = batch.generalized_advantage_estimators
        batch_error = []
        for i in range(self.model_size):
            batch_size = len(states[i])
            if batch_size > 0:
                model = self.get_model(i)
                # reward prediction
                if model.predict_reward:
                    sampled_batch = self.reward_prediction_buffer.sample()
                    reward_prediction_states, reward_prediction_target = self.get_reward_prediction_tuple(sampled_batch)
                else:
                    reward_prediction_states = None
                    reward_prediction_target = None
                # train
                error, train_info = model.train(
                    states=states[i],
                    concats=concats[i],
                    actions=actions[i],
                    values=values[i],
                    policies=policies[i],
                    rewards=rewards[i],
                    discounted_cumulative_rewards=dcr[i],
                    generalized_advantage_estimators=gae[i],
                    reward_prediction_states=reward_prediction_states,
                    reward_prediction_target=reward_prediction_target,
                    internal_state=internal_states[i][0]
                )
                batch_error.append(error)
                # loss statistics
                if flags.print_loss:
                    for key, value in train_info.items():
                        if key not in self._loss_list[i]:
                            self._loss_list[i][key] = deque()
                        self._loss_list[i][key].append(value)
                        if len(self._loss_list[i][key]) > flags.match_count_for_evaluation:
                            # remove old statistics
                            self._loss_list[i][key].popleft()
        return batch_error

    def bootstrap(self, state, concat=None):
        """Store the value estimate of the terminal state for discounting."""
        agent_id = self.agent_id
        internal_state = self.internal_states if flags.share_internal_state else self.internal_states[agent_id]
        value_batch, _ = self.estimate_value(agent_id=agent_id, states=[state],
                                             concats=[concat],
                                             internal_state=internal_state)
        bootstrap = self.batch.bootstrap
        bootstrap['internal_state'] = internal_state
        bootstrap['agent_id'] = agent_id
        bootstrap['state'] = state
        bootstrap['concat'] = concat
        bootstrap['value'] = value_batch[0]

    def replay_value(self, batch):
        """Re-estimate all value predictions in `batch` with the current
        network, then recompute discounted returns."""
        # replay values
        for (agent_id, pos) in batch.step_generator():
            concat, state, internal_state = batch.get_action(
                ['concats', 'states', 'internal_states'], agent_id, pos)
            value_batch, _ = self.estimate_value(agent_id=agent_id,
                                                 states=[state],
                                                 concats=[concat],
                                                 internal_state=internal_state)
            batch.set_action({'values': value_batch[0]}, agent_id, pos)
        if 'value' in batch.bootstrap:
            bootstrap = batch.bootstrap
            agent_id = bootstrap['agent_id']
            value_batch, _ = self.estimate_value(
                agent_id=agent_id, states=[bootstrap['state']],
                concats=[bootstrap['concat']],
                internal_state=bootstrap['internal_state'])
            bootstrap['value'] = value_batch[0]
        return self.compute_discounted_cumulative_reward(batch)

    def add_to_reward_prediction_buffer(self, batch):
        batch_size = batch.get_size(self.agents_set)
        if batch_size < 2:
            return
        batch_extrinsic_reward = batch.get_cumulative_reward(self.agents_set)[0]
        # process batch only after sampling, for better perfomance
        self.reward_prediction_buffer.put(
            batch=batch, type_id=1 if batch_extrinsic_reward != 0 else 0)

    def get_reward_prediction_tuple(self, batch):
        """Build a (states, one-hot target) pair classifying the sign of the
        reward following a random window of up to 3 consecutive states."""
        flat_states = [batch.get_action('states', agent_id, pos)
                       for (agent_id, pos) in batch.step_generator(self.agents_set)]
        flat_rewards = [batch.get_action('rewards', agent_id, pos)
                        for (agent_id, pos) in batch.step_generator(self.agents_set)]
        states_count = len(flat_states)
        length = min(3, states_count-1)
        start_idx = np.random.randint(states_count-length) if states_count > length else 0
        reward_prediction_states = [flat_states[start_idx+i] for i in range(length)]
        reward_prediction_target = np.zeros((1, 3))
        target_reward = flat_rewards[start_idx+length][0]  # use only extrinsic rewards
        if target_reward == 0:
            reward_prediction_target[0][0] = 1.0  # zero
        elif target_reward > 0:
            reward_prediction_target[0][1] = 1.0  # positive
        else:
            reward_prediction_target[0][2] = 1.0  # negative
        return reward_prediction_states, reward_prediction_target

    def add_to_replay_buffer(self, batch, batch_error):
        batch_size = batch.get_size(self.agents_set)
        if batch_size < 1:
            return
        batch_reward = batch.get_cumulative_reward(self.agents_set)
        batch_extrinsic_reward = batch_reward[0]
        batch_intrinsic_reward = batch_reward[1]
        batch_tot_reward = batch_extrinsic_reward + batch_intrinsic_reward
        if batch_tot_reward == 0 and flags.save_only_batches_with_reward:
            return
        if flags.replay_using_default_internal_state:
            batch.reset_internal_states()
        # type 1: intrinsic reward, type 2: extrinsic only, type 0: none.
        type_id = (1 if batch_intrinsic_reward > 0 else (2 if batch_extrinsic_reward > 0 else 0))
        if flags.prioritized_replay:
            self.experience_buffer.put(batch=batch, priority=batch_tot_reward, type_id=type_id)
        else:
            self.experience_buffer.put(batch=batch, type_id=type_id)

    def replay_experience(self):
        """Replay a Poisson-distributed number of stored batches."""
        if not self.experience_buffer.has_atleast(flags.replay_start):
            return
        n = np.random.poisson(flags.replay_ratio)
        for _ in range(n):
            old_batch = self.experience_buffer.sample()
            self.train(self.replay_value(old_batch) if flags.replay_value else old_batch)

    def process_batch(self, global_step):
        """Full pipeline for the just-collected batch: returns, reward
        prediction bookkeeping, training, then replay and buffering."""
        batch = self.compute_discounted_cumulative_reward(self.batch)
        # reward prediction
        if flags.predict_reward:
            # do it before training, this way there will be at least one
            # batch in the reward_prediction_buffer
            self.add_to_reward_prediction_buffer(batch)
            if self.reward_prediction_buffer.is_empty():
                # cannot train without reward prediction, wait until
                # reward_prediction_buffer is not empty
                return
        # train
        batch_error = self.train(batch)
        # experience replay (after training!)
        if flags.replay_ratio > 0 and global_step > flags.replay_step:
            self.replay_experience()
            self.add_to_replay_buffer(batch, batch_error)
class Fdr(ContinualModel):
    """Function Distance Regularization: replays stored softmax outputs and
    penalizes their L2 distance to the current outputs; the buffer is
    rebalanced across tasks at every task boundary."""
    NAME = 'fdr'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Fdr, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.current_task = 0
        self.i = 0  # step counter (only incremented in observe here)
        self.soft = torch.nn.Softmax(dim=1)
        # LogSoftmax is defined but not used in the code shown here.
        self.logsoft = torch.nn.LogSoftmax(dim=1)

    def end_task(self, dataset):
        """Shrink each past task's buffer share to buffer_size/current_task,
        then fill the freed space with the finished task's examples and
        their current logits."""
        self.current_task += 1
        examples_per_task = self.args.buffer_size // self.current_task

        if self.current_task > 1:
            # Rebuild the buffer keeping only `examples_per_task` per task.
            buf_x, buf_log, buf_tl = self.buffer.get_all_data()
            self.buffer.empty()

            for ttl in buf_tl.unique():
                idx = (buf_tl == ttl)
                ex, log, tasklab = buf_x[idx], buf_log[idx], buf_tl[idx]
                first = min(ex.shape[0], examples_per_task)
                self.buffer.add_data(examples=ex[:first],
                                     logits=log[:first],
                                     task_labels=tasklab[:first])
        counter = 0
        with torch.no_grad():
            for i, data in enumerate(dataset.train_loader):
                inputs, labels, not_aug_inputs = data
                inputs = inputs.to(self.device)
                not_aug_inputs = not_aug_inputs.to(self.device)
                outputs = self.net(inputs)
                if examples_per_task - counter < 0:
                    break
                # Slicing caps the insertion at the remaining quota.
                # NOTE(review): task_labels assumes a full `batch_size` batch;
                # a smaller final loader batch would mismatch `examples` —
                # confirm against Buffer.add_data's handling.
                self.buffer.add_data(
                    examples=not_aug_inputs[:(examples_per_task - counter)],
                    logits=outputs.data[:(examples_per_task - counter)],
                    task_labels=(torch.ones(self.args.batch_size) *
                                 (self.current_task - 1))[:(examples_per_task - counter)])
                counter += self.args.batch_size

    def observe(self, inputs, labels, not_aug_inputs):
        """CE step on the stream, then a separate step minimizing the L2
        distance between current and stored softmax outputs on replayed
        samples. Returns the loss of the last step performed."""
        self.i += 1

        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()

        if not self.buffer.is_empty():
            self.opt.zero_grad()
            buf_inputs, buf_logits, _ = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            buf_outputs = self.net(buf_inputs)
            # Per-sample L2 distance between softmax distributions.
            loss = torch.norm(
                self.soft(buf_outputs) - self.soft(buf_logits), 2, 1).mean()
            assert not torch.isnan(loss)
            loss.backward()
            self.opt.step()

        return loss.item()
def train_cl(train_set, test_set, model, loss, optimizer, device, config):
    """Domain-incremental training loop with experience replay.

    Trains `model` on each domain in `train_set` in sequence, mixing every
    batch 50/50 with samples replayed from a buffer, evaluating on past
    (and next) domains after each one, and writing transfer metrics
    (backward/forward transfer, forgetting) to a text file and TensorBoard.

    :param train_set: Train set (one dataset per domain)
    :param test_set: Test set (one dataset per domain)
    :param model: PyTorch model
    :param loss: loss function
    :param optimizer: optimizer
    :param device: device cuda/cpu
    :param config: configuration dict with 'buffer_size', 'batch_size',
        'epochs'
    """
    name = ""
    global_writer = SummaryWriter('./runs/continual/train/global/' + name)
    buffer = Buffer(config['buffer_size'], device)
    accuracy = []

    # Context manager guarantees the results file is closed even on error
    # (the original leaked the handle if an exception was raised mid-run).
    with open("result_" + name + ".txt", "w") as text:
        # Eval without training, as the forward-transfer baseline.
        random_accuracy = evaluate_past(model, len(test_set) - 1, test_set,
                                        loss, device)
        text.write("Evaluation before training" + '\n')
        for a in random_accuracy:
            text.write(f"{a:.2f}% ")
        text.write('\n')

        for index, data_set in enumerate(train_set):
            model.train()
            print(f"----- DOMAIN {index} -----")
            print("Training model...")
            train_loader = DataLoader(data_set,
                                      batch_size=config['batch_size'],
                                      shuffle=False)

            for epoch in tqdm(range(config['epochs'])):
                epoch_loss = []
                epoch_acc = []
                for i, (x, y) in enumerate(train_loader):
                    optimizer.zero_grad()

                    inputs = x.to(device)
                    labels = y.to(device)

                    if not buffer.is_empty():
                        # Strategy 50/50: batch of batch_size from the
                        # loader + batch_size replayed from the buffer.
                        buf_input, buf_label = buffer.get_data(
                            config['batch_size'])
                        inputs = torch.cat((inputs, torch.stack(buf_input)))
                        labels = torch.cat((labels, torch.stack(buf_label)))

                    y_pred = model(inputs)
                    s_loss = loss(y_pred.squeeze(1), labels)
                    acc = binary_accuracy(y_pred.squeeze(1), labels)
                    # Per-epoch running metrics.
                    epoch_loss.append(s_loss.item())
                    epoch_acc.append(acc.item())

                    s_loss.backward()
                    optimizer.step()

                # Buffer the first epoch's (last) raw batch of this domain.
                if epoch == 0:
                    buffer.add_data(examples=x.to(device), labels=y.to(device))

                global_writer.add_scalar('Train_global/Loss',
                                         statistics.mean(epoch_loss),
                                         epoch + (config['epochs'] * index))
                global_writer.add_scalar('Train_global/Accuracy',
                                         statistics.mean(epoch_acc),
                                         epoch + (config['epochs'] * index))

                # FIX: the last-epoch report was hard-coded to `epoch == 499`
                # (only correct for 500 epochs); derive it from config.
                if epoch % 100 == 0 or epoch == config['epochs'] - 1:
                    print(
                        f'\nEpoch {epoch:03}/{config["epochs"]} | Loss: {statistics.mean(epoch_loss):.5f} '
                        f'| Acc: {statistics.mean(epoch_acc):.5f}')

            # Test on domain just trained + old domains.
            evaluation = evaluate_past(model, index, test_set, loss, device)
            accuracy.append(evaluation)

            text.write(f"Evaluation after domain {index}" + '\n')
            for a in evaluation:
                text.write(f"{a:.2f}% ")
            text.write('\n')

            if index != len(train_set) - 1:
                accuracy[index].append(
                    evaluate_next(model, index, test_set, loss, device))

        # Check buffer distribution
        buffer.check_distribution()

        # Compute transfer metrics
        backward = backward_transfer(accuracy)
        forward = forward_transfer(accuracy, random_accuracy)
        forget = forgetting(accuracy)
        print(f'Backward transfer: {backward}')
        print(f'Forward transfer: {forward}')
        print(f'Forgetting: {forget}')

        text.write(f"Backward: {backward}\n")
        text.write(f"Forward: {forward}\n")
        text.write(f"Forgetting: {forget}\n")
class OCILFAST(ContinualModel):
    """One-class incremental learner: one hypersphere network per class.

    For every class a dedicated backbone maps inputs into an embedding space
    and is trained to pull samples of that class toward a per-class center
    ``c`` (Deep-SVDD-style), while pushing other samples away.  Prediction
    scores each input against every class's sphere and is therefore usable
    in both class-IL and task-IL settings.
    """

    NAME = 'OCILFAST'
    COMPATIBILITY = ['class-il', 'task-il']

    def __init__(self, net, loss, args, transform):
        super(OCILFAST, self).__init__(net, loss, args, transform)
        self.nets = []       # one backbone per class, created lazily in begin_task()
        self.c = []          # per-class hypersphere centers (embedding_dim tensors)
        self.threshold = []  # NOTE(review): never filled in this chunk — possibly vestigial
        self.nu = self.args.nu
        self.eta = self.args.eta            # weight of the negative (repulsion) loss term
        self.eps = self.args.eps            # lower bound on |c_i| in init_center_c()
        self.embedding_dim = self.args.embedding_dim
        self.weight_decay = self.args.weight_decay
        self.margin = self.args.margin      # scales the cross-class margin loss in train_category()
        self.current_task = 0
        self.cpt = None      # classes per task (set in begin_task from the dataset)
        self.nc = None       # total number of classes across all tasks
        self.eye = None      # lower-triangular boolean mask, built in begin_task
        self.buffer_size = self.args.buffer_size
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.nf = self.args.nf
        # Dataset-dependent shift applied to every network input
        # (presumably recentering [0, 1] pixels to [-0.5, 0.5] — TODO confirm).
        if self.args.dataset == 'seq-cifar10' or self.args.dataset == 'seq-mnist':
            self.input_offset = -0.5
        elif self.args.dataset == 'seq-tinyimg':
            self.input_offset = 0
        else:
            self.input_offset = 0

    # Task initialization (translated from the original Chinese comment).
    def begin_task(self, dataset):
        """Lazily build per-class networks/centers and advance the task counter."""
        if self.cpt is None:
            self.cpt = dataset.N_CLASSES_PER_TASK
            self.nc = dataset.N_TASKS * self.cpt
            self.eye = torch.tril(torch.ones((self.nc, self.nc))).bool().to(
                self.device)  # lower triangle (incl. diagonal) True, upper False; used as a mask
        if len(self.nets) == 0:
            # One backbone + one (placeholder) center per class, for all tasks up front.
            for i in range(self.nc):
                self.nets.append(
                    get_backbone(self.net, self.embedding_dim, self.nc,
                                 self.nf).to(self.device))
                self.c.append(
                    torch.ones(self.embedding_dim, device=self.device))
        self.current_task += 1

    def train_model(self, dataset, train_loader):
        """Train one one-class network per category of the current task.

        For each category: optionally recompute distances to all previous
        categories (reset_train_loader), run n_epochs of train_category,
        dump diagnostic loss histograms every 5 epochs, and finally refresh
        the replay buffer from the current loader.
        """
        categories = list(
            range((self.current_task - 1) * self.cpt,
                  (self.current_task) * self.cpt))
        print('==========\t task: %d\t categories:' % self.current_task,
              categories, '\t==========')
        if self.args.print_file:
            print('==========\t task: %d\t categories:' % self.current_task,
                  categories, '\t==========', file=self.args.print_file)
        for category in categories:
            losses = []
            if category > 0:
                # Precompute per-sample distances to every earlier category's
                # sphere; train_category uses them for the margin term.
                self.reset_train_loader(train_loader, category)
            for epoch in range(self.args.n_epochs):
                avg_loss, maxloss, posdist, negdist, gloloss = self.train_category(
                    train_loader, category, epoch)
                losses.append(avg_loss)
                if epoch == 0 or (epoch + 1) % 5 == 0:
                    print("epoch: %d\t task: %d \t category: %d \t loss: %f" %
                          (epoch + 1, self.current_task, category, avg_loss))
                    if self.args.print_file:
                        print(
                            "epoch: %d\t task: %d \t category: %d \t loss: %f"
                            % (epoch + 1, self.current_task, category,
                               avg_loss), file=self.args.print_file)
                    # Diagnostic 2x2 figure: distributions of the four loss
                    # components for this epoch.  distplot can raise on
                    # degenerate data, hence the guards.
                    plt.figure(figsize=(20, 12))
                    ax = plt.subplot(2, 2, 1)
                    ax.set_title('maxloss')
                    plt.xlim((0, 2))
                    if maxloss is not None:
                        try:
                            sns.distplot(maxloss)
                        except:
                            pass
                    ax = plt.subplot(2, 2, 2)
                    ax.set_title('posdist')
                    plt.xlim((0, 2))
                    try:
                        sns.distplot(posdist)
                    except:
                        print(posdist)
                    ax = plt.subplot(2, 2, 3)
                    ax.set_title('negdist')
                    plt.xlim((0, 2))
                    try:
                        sns.distplot(negdist)
                    except:
                        print(negdist)
                    ax = plt.subplot(2, 2, 4)
                    ax.set_title('gloloss')
                    plt.xlim((0, 2))
                    try:
                        sns.distplot(gloloss)
                    except:
                        print(gloloss)
                    plt.savefig("../" + self.args.img_dir +
                                "/loss-cat%d-epoch%d.png" % (category, epoch))
                    plt.clf()
            # Per-category training-loss curve over epochs.
            x = list(range(len(losses)))
            plt.plot(x, losses)
            plt.savefig("../" + self.args.img_dir + "/loss-cat%d.png" %
                        (category))
            plt.clf()
        self.fill_buffer(train_loader)

    def reset_train_loader(self, train_loader, category):
        """Attach, to the loader's dataset, each sample's distance to every
        previously-trained category sphere (consumed as ``prev_dists`` by
        train_category's margin term)."""
        dataset = train_loader.dataset
        input = dataset.data
        # shuffle=False so prev_dists rows stay aligned with dataset order.
        loader = DataLoader(dataset, batch_size=self.args.batch_size,
                            shuffle=False)
        inputs = []
        targets = []
        prev_dists = []
        prev_categories = list(range(category))
        print('prev_categories', prev_categories)
        if self.args.print_file:
            print('prev_categories', prev_categories,
                  file=self.args.print_file)
        for i, data in enumerate(loader):
            input, target, _ = data
            _, prev_dist = self.predict(input, prev_categories)
            inputs.append(input.detach().cpu())
            targets.append(target.detach().cpu())
            prev_dists.append(prev_dist.detach().cpu())
        inputs = torch.cat(inputs, dim=0)
        targets = torch.cat(targets, dim=0)
        prev_dists = torch.cat(prev_dists, dim=0)
        # NOTE(review): set_prevdist is a method of the project dataset class
        # (not visible here) — presumably stores the distances for __getitem__.
        dataset.set_prevdist(prev_dists)

    def train_category(self, data_loader, category: int, epoch_id):
        """Run one epoch of one-class training for ``category``.

        Loss per sample: positives (label == category) pay a hinge on the
        squared distance to the center beyond radius ``r`` plus, when earlier
        categories exist, a margin term penalizing being closer to this sphere
        than to previous ones; negatives pay eta/dist (repulsion).

        Returns (avg_loss, maxloss, posdist, negdist, gloloss) where the last
        four are numpy arrays of per-sample loss components for diagnostics
        (maxloss is None when category == 0).
        """
        self.init_center_c(data_loader, category)
        c = self.c[category]
        network = self.nets[category].to(self.device)
        network.train()
        optimizer = SGD(network.parameters(), lr=self.args.lr,
                        weight_decay=self.weight_decay)
        avg_loss = 0.0
        sample_num = 0
        maxloss = []
        posdist = []
        negdist = []
        gloloss = []
        prev_categories = list(range(category))
        for i, data in enumerate(data_loader):
            inputs, semi_targets, prev_dists = data
            inputs = inputs.to(self.device)
            semi_targets = semi_targets.to(self.device)
            prev_dists = prev_dists.to(self.device)
            if (not self.buffer.is_empty()) and self.args.buffer_size > 0:
                # Mix in a replay batch; replayed samples carry their own
                # labels, so they act as extra positives/negatives below.
                buf_inputs, buf_labels = self.buffer.get_data(
                    self.args.minibatch_size, transform=self.transform)
                # print(buf_inputs[0])
                inputs = torch.cat((inputs, buf_inputs))
                semi_targets = torch.cat((semi_targets, buf_labels))
            # Zero the network parameter gradients
            optimizer.zero_grad()
            # Note: the network input must have the dataset offset applied
            # (e.g. subtract 0.5; see self.input_offset in __init__).
            outputs = network(inputs + self.input_offset)
            dists = torch.sum((outputs - c)**2, dim=1)
            # Hinge: no penalty inside radius r.
            pos_dist_loss = torch.relu(dists - self.args.r)
            if category > 0:
                # Margin vs. previously trained spheres: penalize samples
                # whose distance to this center undercuts their distance to
                # earlier centers; averaged over the number of old categories.
                max_scores = torch.relu(dists.view(-1, 1) - prev_dists)
                max_loss = torch.sum(max_scores,
                                     dim=1) * self.margin / category
                loss_pos = pos_dist_loss + max_loss
                loss_neg = self.eta * dists**-1
                pos_max_loss = max_loss[semi_targets == category]
                maxloss.append(pos_max_loss.detach().cpu().data.numpy())
            else:
                loss_pos = pos_dist_loss
                loss_neg = self.eta * dists**-1
            # Positives attract toward c, all other labels are repelled.
            losses = torch.where(semi_targets == category, loss_pos, loss_neg)
            gloloss.append(losses.detach().cpu().data.numpy())
            loss = torch.mean(losses)
            loss.backward()
            optimizer.step()
            # Record the loss components (translated from Chinese comment).
            pos_dist = pos_dist_loss[semi_targets == category]
            posdist.append(pos_dist.detach().cpu().data.numpy())
            neg_dist = loss_neg[semi_targets != category]
            negdist.append(neg_dist.detach().cpu().data.numpy())
            avg_loss += loss.item()
            sample_num += inputs.shape[0]
            # Old categories are trained for only one batch per epoch
            # (translated from Chinese comment).
            if category < (self.current_task - 1) * self.cpt:
                break
        avg_loss /= sample_num
        if len(maxloss) > 0:
            maxloss = np.hstack(maxloss)
        else:
            maxloss = None
        posdist = np.hstack(posdist)
        negdist = np.hstack(negdist)
        gloloss = np.hstack(gloloss)
        return avg_loss, maxloss, posdist, negdist, gloloss

    def fill_buffer(self, train_loader):
        """Push every (non-augmented) batch of the loader into the replay buffer."""
        for data in train_loader:
            # get the inputs of the batch
            inputs, semi_targets, not_aug_inputs = data
            self.buffer.add_data(examples=not_aug_inputs,
                                 labels=semi_targets)

    def init_center_c(self, train_loader: DataLoader, category):
        """Initialize hypersphere center c as the mean from an initial forward pass on the data."""
        n_samples = 0
        c = 0
        net = self.nets[category].to(self.device)
        net.eval()
        with torch.no_grad():
            for data in train_loader:
                # get the inputs of the batch
                inputs, semi_targets, not_aug_inputs = data
                inputs = inputs.to(self.device)
                semi_targets = semi_targets.to(self.device)
                outputs = net(inputs + self.input_offset)
                # Use all positive samples (label == category) for the center
                # initialization (translated from Chinese comment).
                outputs = outputs[semi_targets == category]
                # print(outputs)
                n_samples += outputs.shape[0]
                c += torch.sum(outputs, dim=0)
        c /= n_samples
        # If c_i is too close to 0, set to +-eps. Reason: a zero unit can be trivially matched with zero weights.
        c[(abs(c) < self.eps) & (c < 0)] = -self.eps
        c[(abs(c) < self.eps) & (c > 0)] = self.eps
        self.c[category] = c.to(self.device)

    def get_score(self, dist, category):
        """Convert a squared distance into a score (closer => larger);
        1e-6 guards against division by zero."""
        score = 1 / (dist + 1e-6)
        return score

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score x against every class seen so far; returns (batch, n_seen_classes)."""
        categories = list(range(self.current_task * self.cpt))
        return self.predict(x, categories)[0]

    def predict(self, inputs: torch.Tensor, categories):
        """Evaluate ``inputs`` against each listed category's sphere.

        Returns (outcome, dists): per-category scores and squared distances,
        each of shape (batch, len(categories)).
        """
        inputs = inputs.to(self.device)
        outcome, dists = [], []
        with torch.no_grad():
            for i in categories:
                net = self.nets[i]
                net.to(self.device)
                net.eval()
                c = self.c[i].to(self.device)
                pred = net(inputs + self.input_offset)
                dist = torch.sum((pred - c)**2, dim=1)
                scores = self.get_score(dist, i)
                outcome.append(scores.view(-1, 1))
                dists.append(dist.view(-1, 1))
        outcome = torch.cat(outcome, dim=1)
        dists = torch.cat(dists, dim=1)
        return outcome, dists