示例#1
0
文件: fdr.py 项目: yxue3357/mammoth
 def __init__(self, backbone, loss, args, transform):
     """Initialise the model, its replay buffer, counters and softmax layers.

     :param backbone: network forwarded to the parent constructor
     :param loss: loss function forwarded to the parent constructor
     :param args: parsed arguments; ``args.buffer_size`` sizes the buffer
     :param transform: data transform forwarded to the parent constructor
     """
     super().__init__(backbone, loss, args, transform)
     memory_size = self.args.buffer_size
     self.buffer = Buffer(memory_size, self.device)
     # Both counters start at zero.
     self.current_task = self.i = 0
     # Softmax / log-softmax over dim 1 (presumably the class axis of the logits).
     self.soft, self.logsoft = torch.nn.Softmax(dim=1), torch.nn.LogSoftmax(dim=1)
示例#2
0
文件: er.py 项目: yxue3357/mammoth
class Er(ContinualModel):
    """Experience-replay continual model (``NAME = 'er'``).

    Every training step concatenates the incoming batch with a minibatch
    sampled from the replay buffer, takes one optimizer step on the joint
    loss, and then stores the non-augmented incoming examples in the buffer.
    """
    NAME = 'er'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super().__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def observe(self, inputs, labels, not_aug_inputs):
        """Run one replay-augmented training step.

        :param inputs: the (augmented) incoming batch
        :param labels: labels for ``inputs``
        :param not_aug_inputs: the incoming batch before augmentation; this is
            what gets written into the buffer
        :return: the scalar loss as a Python float
        """
        n_stream = inputs.shape[0]

        self.opt.zero_grad()
        batch_x, batch_y = inputs, labels
        if not self.buffer.is_empty():
            replay_x, replay_y = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            batch_x = torch.cat((batch_x, replay_x))
            batch_y = torch.cat((batch_y, replay_y))

        loss = self.loss(self.net(batch_x), batch_y)
        loss.backward()
        self.opt.step()

        # Only the incoming samples (the first n_stream rows) go back into memory.
        self.buffer.add_data(examples=not_aug_inputs,
                             labels=batch_y[:n_stream])

        return loss.item()
示例#3
0
 def decode(self, buffer: Buffer, conn):
     """Consume everything currently readable in *buffer* and return it as text.

     :param buffer: the connection's read buffer
     :param conn: the connection object (not used by this codec)
     :return: the UTF-8 decoded string, or ``None`` when nothing is readable
     """
     if buffer.readable_length() == 0:
         return None
     # NOTE(review): assumes buffer.slice(0) yields all readable bytes (the
     # same amount marked as read below) — confirm against Buffer.slice.
     msg = buffer.slice(0).decode()
     buffer.has_read(buffer.readable_length())
     return msg
示例#4
0
    def __init__(self, obs_size, act_size):
        """Build actor/critic networks, their target copies, buffer and optimisers.

        :param obs_size: observation size, forwarded to Actor and Critic
        :param act_size: action size, forwarded to Actor and Critic
        """
        # Constant hyper-parameters
        self.lr_actor = 1e-4
        self.lr_critic = 1e-3
        self.w_decay = 1e-2  # L2 weight decay for Q
        self.to = 1e-3  # Soft target update rate
        # Replay capacity. This used to be 1e-6 (a tiny float, almost certainly
        # a typo for 1e6), which is not a usable buffer size.
        self.buffer_size = int(1e6)
        self.minibatch_size = 64
        self.mean = 0  # mean of the Gaussian exploration noise
        self.sigma = 1  # std of the Gaussian exploration noise
        self.gemma = 0.99  # discount factor (attribute name kept for compatibility)

        # Initializing networks
        self.actor = Actor(obs_size, act_size)
        self.actor_bar = Actor(obs_size, act_size)

        self.critic = Critic(obs_size, act_size)
        self.critic_bar = Critic(obs_size, act_size)

        # Make actor_bar and critic_bar start with the same weights
        hard_update(self.actor_bar, self.actor)
        hard_update(self.critic_bar, self.critic)

        # Initializing buffer
        self.buffer = Buffer(self.buffer_size)

        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=self.lr_critic,
                                                 weight_decay=self.w_decay)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=self.lr_actor)
        self.mse_loss = nn.MSELoss()
示例#5
0
class Der(ContinualModel):
    """Replay model that stores and regresses onto past network outputs (``NAME = 'der'``).

    The buffer holds (input, logits) pairs. Each step adds, to the loss on the
    incoming batch, ``args.alpha`` times the MSE between the current outputs
    on buffered inputs and the logits stored for them.
    """
    NAME = 'der'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super().__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def observe(self, inputs, labels, not_aug_inputs):
        """One training step; returns the total loss as a float.

        :param inputs: the (augmented) incoming batch
        :param labels: labels for ``inputs``
        :param not_aug_inputs: non-augmented batch, stored in the buffer
        """
        self.opt.zero_grad()

        outputs = self.net(inputs)
        total = self.loss(outputs, labels)

        if not self.buffer.is_empty():
            mem_x, mem_logits = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            distill = F.mse_loss(self.net(mem_x), mem_logits)
            total += self.args.alpha * distill

        total.backward()
        self.opt.step()
        # Store this step's raw outputs (without autograd history) as targets.
        self.buffer.add_data(examples=not_aug_inputs, logits=outputs.data)

        return total.item()
示例#6
0
    def __init__(self, net, loss, args, transform):
        """Set up per-task containers, hyper-parameters, buffer and input offset.

        :param net: network forwarded to the parent constructor
        :param loss: loss function forwarded to the parent constructor
        :param args: parsed arguments supplying the hyper-parameters below
        :param transform: data transform forwarded to the parent constructor
        """
        super().__init__(net, loss, args, transform)

        # Per-task containers, empty until tasks are seen.
        self.nets = []
        self.c = []
        self.threshold = []

        # Hyper-parameters copied from the parsed arguments.
        hp = self.args
        self.nu = hp.nu
        self.eta = hp.eta
        self.eps = hp.eps
        self.embedding_dim = hp.embedding_dim
        self.weight_decay = hp.weight_decay
        self.margin = hp.margin

        self.current_task = 0
        self.cpt = None
        self.nc = None
        self.eye = None
        self.buffer_size = hp.buffer_size
        self.buffer = Buffer(hp.buffer_size, self.device)
        self.nf = hp.nf

        # seq-cifar10 / seq-mnist get an input offset of -0.5; every other
        # dataset (seq-tinyimg included) uses 0.
        self.input_offset = -0.5 if hp.dataset in ('seq-cifar10', 'seq-mnist') else 0
示例#7
0
 def __init__(self, sock: socket.socket, codec: Codec, processor):
     """Wrap a connected socket with its codec, processor and per-connection state.

     :param sock: the connected socket
     :param codec: codec stored for this connection
     :param processor: processor stored for this connection
     """
     self._socket, self._codec, self._processor = sock, codec, processor
     # Per-connection state: read buffer, stop flag and a queue.
     self._buffer = Buffer()
     self._stop = False
     self._queue = Queue()
     # Initialised to the creation time (seconds since the epoch).
     self._last_active_time = time.time()
     # Remote address; socket.getpeername() raises OSError if not connected.
     self.address = self._socket.getpeername()
示例#8
0
    def __init__(self, backbone, loss, args, transform):
        """Create the replay buffer and flat gradient scratch vectors.

        :param backbone: network forwarded to the parent constructor
        :param loss: loss function forwarded to the parent constructor
        :param args: parsed arguments; ``args.buffer_size`` sizes the buffer
        :param transform: data transform forwarded to the parent constructor
        """
        super().__init__(backbone, loss, args, transform)

        self.buffer = Buffer(self.args.buffer_size, self.device)
        # Element count of every parameter tensor, in parameters() order.
        self.grad_dims = [param.data.numel() for param in self.parameters()]
        flat_size = np.sum(self.grad_dims)
        # torch.Tensor(n) allocates n uninitialised values; presumably these
        # are overwritten (e.g. by store_grad) before being read.
        self.grad_xy = torch.Tensor(flat_size).to(self.device)
        self.grad_er = torch.Tensor(flat_size).to(self.device)
示例#9
0
class Mer(ContinualModel):
    """Replay model with Reptile-style meta-updates (``NAME = 'mer'``).

    For each incoming example it builds ``args.batch_num`` small batches
    (buffer samples plus the example), takes an SGD step on each, and
    interpolates the weights with factor ``args.beta`` after each step and
    ``args.gamma`` across all steps.
    """
    NAME = 'mer'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super().__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)

    def draw_batches(self, inp, lab):
        """Return ``args.batch_num`` (inputs, labels) pairs built around one example.

        Each pair is a fresh buffer minibatch with ``inp`` appended. While the
        buffer is empty, it is just ``inp`` as a batch of one.
        """
        batches = []
        for _ in range(self.args.batch_num):
            if self.buffer.is_empty():
                batches.append((inp.unsqueeze(0),
                                torch.tensor([lab]).unsqueeze(0).to(self.device)))
                continue
            mem_x, mem_y = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            batches.append((torch.cat((mem_x, inp.unsqueeze(0))),
                            torch.cat((mem_y, torch.tensor([lab]).to(self.device)))))
        return batches

    def observe(self, inputs, labels, not_aug_inputs):
        """Meta-learning step on one incoming example; returns the last batch loss.

        Assumes ``inputs`` is a single example without a batch dimension (it
        is unsqueezed to a batch of one) — confirm with the caller.
        """
        batches = self.draw_batches(inputs, labels)
        start_params = self.net.get_params().data.clone()

        for batch_inputs, batch_labels in batches:
            before_step = self.net.get_params().data.clone()

            # within-batch SGD step
            self.opt.zero_grad()
            loss = self.loss(self.net(batch_inputs), batch_labels.squeeze(-1))
            loss.backward()
            self.opt.step()

            # within-batch Reptile update: keep only a beta fraction of the step
            self.net.set_params(before_step + self.args.beta *
                                (self.net.get_params() - before_step))

        self.buffer.add_data(examples=not_aug_inputs.unsqueeze(0),
                             labels=labels)

        # across-batch Reptile update, relative to the parameters at entry
        self.net.set_params(start_params + self.args.gamma *
                            (self.net.get_params() - start_params))

        return loss.item()
示例#10
0
 def decode(self, buffer: Buffer, conn):
     """Try to pull one length-prefixed RPC frame out of *buffer*.

     As read by the code below, the first ``RPC_HEADER_LEN`` bytes are a
     little-endian payload length; the payload follows. This assumes
     ``buffer.slice(n)`` returns the first n readable bytes without consuming
     them — confirm against Buffer.

     :return: ``codec_decode(payload)`` once a full frame is buffered,
         otherwise ``None`` (and nothing is consumed)
     """
     available = buffer.readable_length()
     # NOTE(review): `<=` means a frame with an empty payload is never
     # decoded; confirm zero-length payloads cannot occur.
     if available <= RPC_HEADER_LEN:
         return None
     header = buffer.slice(RPC_HEADER_LEN)
     frame_len = int.from_bytes(header, 'little') + RPC_HEADER_LEN
     if frame_len > available:
         # Partial frame: wait for more bytes.
         return None
     payload = buffer.slice(frame_len)[RPC_HEADER_LEN:]
     buffer.has_read(frame_len)
     return codec_decode(payload)
示例#11
0
    def __init__(self, backbone, loss, args, transform):
        """Load the dataset description and set up buffer, one-hot table and state.

        :param backbone: network forwarded to the parent constructor
        :param loss: loss function forwarded to the parent constructor
        :param args: parsed arguments (dataset selection, ``buffer_size``)
        :param transform: data transform forwarded to the parent constructor
        """
        super().__init__(backbone, loss, args, transform)
        self.dataset = get_dataset(args)

        # Replay memory.
        self.buffer = Buffer(self.args.buffer_size, self.device)
        # Identity matrix over every class of every task (rows = one-hot labels).
        total_classes = self.dataset.N_CLASSES_PER_TASK * self.dataset.N_TASKS
        self.eye = torch.eye(total_classes).to(self.device)

        # No class means or previous network yet; start at task 0.
        self.class_means = None
        self.old_net = None
        self.current_task = 0
示例#12
0
    def __init__(self, backbone, loss, args, transform):
        """Create the buffer and the gradient bookkeeping used by the model.

        :param backbone: network forwarded to the parent constructor
        :param loss: loss function forwarded to the parent constructor
        :param args: parsed arguments; ``args.buffer_size`` sizes the buffer
        :param transform: data transform forwarded to the parent constructor
        """
        super().__init__(backbone, loss, args, transform)
        self.current_task = 0
        self.buffer = Buffer(self.args.buffer_size, self.device)

        # Temporary gradient ("synaptic") memory: per-parameter element counts,
        # a list for stored gradients and a zeroed flat vector of total size.
        self.grad_dims = [param.data.numel() for param in self.parameters()]
        self.grads_cs = []
        self.grads_da = torch.zeros(np.sum(self.grad_dims)).to(self.device)
示例#13
0
class AGem(ContinualModel):
    """Gradient-projection replay model (``NAME = 'agem'``).

    At the end of each task a slice of that task's data is stored in the
    buffer. During training the gradient of the incoming batch is compared
    with the gradient on a buffer minibatch. If their dot product is
    negative, the batch gradient is replaced by ``project(...)`` of the two
    before stepping.
    """
    NAME = 'agem'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super().__init__(backbone, loss, args, transform)

        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.grad_dims = [param.data.numel() for param in self.parameters()]
        flat_size = np.sum(self.grad_dims)
        self.grad_xy = torch.Tensor(flat_size).to(self.device)
        self.grad_er = torch.Tensor(flat_size).to(self.device)
        # Buffer samples are transformed only when args.iba is set.
        self.transform = transform if self.args.iba else None

    def end_task(self, dataset):
        """Store the first non-augmented batch of this task (an equal share of the buffer)."""
        per_task = self.args.buffer_size // dataset.N_TASKS
        loader = dataset.not_aug_dataloader(self.args, per_task)
        first_batch = next(iter(loader))
        cur_x, cur_y = first_batch[:2]
        self.buffer.add_data(examples=cur_x.to(self.device),
                             labels=cur_y.to(self.device))

    def observe(self, inputs, labels, not_aug_inputs):
        """One (possibly projected) gradient step; returns the batch loss."""
        self.zero_grad()
        loss = self.loss(self.net.forward(inputs), labels)
        loss.backward()

        if not self.buffer.is_empty():
            # store_grad presumably copies the current .grad values into grad_xy.
            store_grad(self.parameters, self.grad_xy, self.grad_dims)

            mem_x, mem_y = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            self.net.zero_grad()
            penalty = self.loss(self.net.forward(mem_x), mem_y)
            penalty.backward()
            store_grad(self.parameters, self.grad_er, self.grad_dims)

            conflicting = torch.dot(self.grad_xy, self.grad_er).item() < 0
            new_grad = (project(gxy=self.grad_xy, ger=self.grad_er)
                        if conflicting else self.grad_xy)
            overwrite_grad(self.parameters, new_grad, self.grad_dims)

        self.opt.step()

        return loss.item()
示例#14
0
 def __init__(self, backbone, loss, args, transform):
     """Set up the ring buffer, HAL hyper-parameters and the spare model.

     :param backbone: network forwarded to the parent constructor
     :param loss: loss function forwarded to the parent constructor
     :param args: parsed arguments (``buffer_size``, ``hal_lambda``, ``beta``,
         ``gamma``, ``lr``, dataset selection)
     :param transform: data transform forwarded to the parent constructor
     """
     super(HAL, self).__init__(backbone, loss, args, transform)
     self.task_number = 0
     # Build the dataset once and reuse it; get_dataset(args) used to be
     # called twice, constructing two separate dataset objects.
     self.dataset = get_dataset(args)
     # Ring-mode buffer; the third positional argument is the number of tasks.
     self.buffer = Buffer(self.args.buffer_size,
                          self.device,
                          self.dataset.N_TASKS,
                          mode='ring')
     self.hal_lambda = args.hal_lambda
     self.beta = args.beta
     self.gamma = args.gamma
     self.anchor_optimization_steps = 100
     self.finetuning_epochs = 1
     # Separate backbone instance with its own SGD optimiser.
     self.spare_model = self.dataset.get_backbone()
     self.spare_model.to(self.device)
     self.spare_opt = SGD(self.spare_model.parameters(), lr=self.args.lr)
示例#15
0
class AGemr(ContinualModel):
    """Gradient-projection replay model with an online buffer (``NAME = 'agem_r'``).

    Each step compares the incoming-batch gradient with the gradient on a
    buffer minibatch and projects it when their dot product is negative.
    Every incoming batch is then added to the buffer.
    """
    NAME = 'agem_r'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super().__init__(backbone, loss, args, transform)

        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.grad_dims = [param.data.numel() for param in self.parameters()]
        n_flat = np.sum(self.grad_dims)
        self.grad_xy = torch.Tensor(n_flat).to(self.device)
        self.grad_er = torch.Tensor(n_flat).to(self.device)
        self.current_task = 0

    def observe(self, inputs, labels, not_aug_inputs):
        """One (possibly projected) gradient step, then buffer insertion."""
        self.zero_grad()
        loss = self.loss(self.net.forward(inputs), labels)
        loss.backward()

        if not self.buffer.is_empty():
            # store_grad presumably copies the current .grad values into grad_xy.
            store_grad(self.parameters, self.grad_xy, self.grad_dims)

            # NOTE(review): no transform= here, unlike the other replay models
            # in this file — confirm buffer samples should be used as stored.
            mem_x, mem_y = self.buffer.get_data(self.args.minibatch_size)
            self.net.zero_grad()
            penalty = self.loss(self.net.forward(mem_x), mem_y)
            penalty.backward()
            store_grad(self.parameters, self.grad_er, self.grad_dims)

            if torch.dot(self.grad_xy, self.grad_er).item() < 0:
                replacement = project(gxy=self.grad_xy, ger=self.grad_er)
            else:
                replacement = self.grad_xy
            overwrite_grad(self.parameters, replacement, self.grad_dims)

        self.opt.step()

        self.buffer.add_data(examples=not_aug_inputs, labels=labels)

        return loss.item()
	def __init__(self, session, device, id, action_shape, state_shape, concat_size=0, global_network=None, training=True):
		"""Create a network wrapper holding one or more sub-agents.

		:param session: session object stored on the instance
		:param device: device identifier stored on the instance
		:param id: identifier stored on the instance
		:param action_shape: action shape forwarded to build_agents
		:param state_shape: state shape stored and forwarded to build_agents
		:param concat_size: extra input size forwarded to build_agents
		:param global_network: shared network. With ``training=False`` its
			``model_list`` is reused. For a non-global training instance its
			``clip`` is copied and the instance is bound to it.
		:param training: build agents, buffers and optimiser bindings when True
		"""
		self.training = training
		self.session = session
		self.id = id
		self.device = device
		self.state_shape = state_shape
		self.set_model_size()
		if not self.training:
			# Inference only: reuse the sub-networks of the given network.
			self.global_network = None
			self.model_list = global_network.model_list
		else:
			self.global_network = global_network
			# The global instance builds the optimiser; others copy its clip range.
			if self.is_global_network():
				self.initialize_gradient_optimizer()
			else:
				self.clip = self.global_network.clip
			self.model_list = []
			self.build_agents(state_shape=state_shape, action_shape=action_shape, concat_size=concat_size)
			# Experience buffer (prioritised or plain) and reward-prediction buffer.
			if flags.replay_ratio > 0:
				buffer_cls = PrioritizedBuffer if flags.prioritized_replay else Buffer
				self.experience_buffer = buffer_cls(size=flags.replay_buffer_size)
			if flags.predict_reward:
				self.reward_prediction_buffer = Buffer(size=flags.reward_prediction_buffer_size)
			# Non-global instances bind their optimisation to the global network.
			if not self.is_global_network():
				self.bind_to_global(self.global_network)
			if flags.use_count_based_exploration_reward:
				self.projection = None
				self.projection_dataset = []
			if flags.print_loss:
				self._loss_list = [{} for _ in range(self.model_size)]
		# Statistics
		self._model_usage_list = deque()
 def build_agents(self, state_shape, action_shape, concat_size):
     """Create ``self.model_size`` sub-agents and, on the global network, the partitioner.

     The agent class is looked up by name as ``flags.network + "_Network"``
     in this module's globals.

     :param state_shape: state shape passed to each agent
     :param action_shape: action shape passed to each agent
     :param concat_size: extra input size passed to each agent
     :raises NameError: if no class with that name is defined in this module
     """
     # partitioner
     if self.is_global_network():
         self.buffer = Buffer(size=flags.partitioner_dataset_size)
         self.partitioner = KMeans(n_clusters=self.model_size)
     self.partitioner_trained = False
     # Resolve the agent class once, by name. This dictionary lookup replaces
     # eval(), which would execute any expression placed in the flag. The
     # exception type for an unknown name (NameError) is unchanged.
     network_name = flags.network + "_Network"
     try:
         network_cls = globals()[network_name]
     except KeyError:
         raise NameError("name '%s' is not defined" % network_name) from None
     # agents
     self.model_list = []
     for i in range(self.model_size):
         agent = network_cls(
             id="{0}_{1}".format(self.id, i),
             device=self.device,
             session=self.session,
             state_shape=state_shape,
             action_shape=action_shape,
             concat_size=concat_size,
             clip=self.clip[i],
             predict_reward=flags.predict_reward,
             training=self.training)
         self.model_list.append(agent)
     # bind partition nets to training net
     if self.is_global_network():
         self.bind_to_training_net()
         self.lock = threading.Lock()
示例#18
0
    def _key_exchange(self, conn):
        """Initiate a DH key exchange (send/receive public keys).

        Sends this side's public key in a ``Message.KEY_EXCHG`` message, reads
        the peer's reply, and derives ``self._session_secret`` from the peer's
        public key.

        :param conn: connected object providing ``send`` and ``recv``
        :raises AssertionError: if the reply's code is not ``Message.KEY_EXCHG``
        """
        public_key = self._keys.get_public().to_hex()

        # Create the message object with server public key
        msg = Message(Message.KEY_EXCHG, public_key=public_key)

        # Send the message buffer.
        # NOTE(review): if conn is a plain socket, send() may transmit only
        # part of the data; sendall() is probably intended — confirm.
        conn.send(msg.buffer)

        # Read the client public key.
        # NOTE(review): one recv() can return a partial message on a stream
        # socket — confirm the reply always fits and arrives in one read.
        raw = conn.recv(self.BUFFER_SIZE)
        recv_msg = Message.from_buffer(raw)

        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently accept an unexpected handshake
        # message. AssertionError is kept so existing handlers still match.
        if recv_msg['code'] != Message.KEY_EXCHG:
            raise AssertionError(
                'Unexpected message code during key exchange: %d' % recv_msg['code'])

        # Create the shared secret
        buf = Buffer.from_hex(recv_msg['public_key'])
        self._session_secret = self._keys.session_key(buf)
示例#19
0
class HAL(ContinualModel):
    """Replay model regularised on learned anchor inputs (``NAME = 'hal'``).

    Training combines experience replay from a ring buffer with a penalty that
    keeps the network's outputs on "anchor" inputs stable across each update.
    One anchor per class is optimised at the end of every task (see
    :meth:`get_anchors`).
    """
    NAME = 'hal'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(HAL, self).__init__(backbone, loss, args, transform)
        self.task_number = 0
        # Build the dataset once and reuse it; get_dataset(args) used to be
        # called twice, constructing two separate dataset objects.
        self.dataset = get_dataset(args)
        # Ring-mode buffer; the third positional argument is the number of tasks.
        self.buffer = Buffer(self.args.buffer_size,
                             self.device,
                             self.dataset.N_TASKS,
                             mode='ring')
        self.hal_lambda = args.hal_lambda  # weight of the anchor penalty
        self.beta = args.beta  # EMA factor of the running feature mean phi
        self.gamma = args.gamma  # weight of the anchor feature-matching term
        self.anchor_optimization_steps = 100
        self.finetuning_epochs = 1
        # Scratch backbone used to fine-tune and to optimise anchors.
        self.spare_model = self.dataset.get_backbone()
        self.spare_model.to(self.device)
        self.spare_opt = SGD(self.spare_model.parameters(), lr=self.args.lr)

    def end_task(self, dataset):
        """Advance the task counter, reset ring-buffer bookkeeping, build anchors."""
        self.task_number += 1
        # ring buffer mgmt (skipped if we are loading a model already at this task)
        if self.task_number > self.buffer.task_number:
            self.buffer.num_seen_examples = 0
            self.buffer.task_number = self.task_number
        # get anchors (provided that we are not loading the model)
        if len(self.anchors) < self.task_number * dataset.N_CLASSES_PER_TASK:
            self.get_anchors(dataset)
            # phi is rebuilt from scratch on the next observe() call
            del self.phi

    def get_anchors(self, dataset):
        """Optimise one anchor input per class of the current training set.

        The spare model is first fine-tuned on buffer batches, giving weights
        ``theta_m``. Each anchor is initialised with ``torch.rand`` and moved
        by SGD to:
        - raise the loss under ``theta_m``,
        - lower the loss under the current weights ``theta_t``,
        - keep its features close to ``self.phi``.
        Finished anchors are appended to ``self.anchors``.
        """
        theta_t = self.net.get_params().detach().clone()
        self.spare_model.set_params(theta_t)

        # fine tune on memory buffer (one buffer batch per epoch)
        for _ in range(self.finetuning_epochs):
            inputs, labels = self.buffer.get_data(self.args.batch_size,
                                                  transform=self.transform)
            self.spare_opt.zero_grad()
            out = self.spare_model(inputs)
            loss = self.loss(out, labels)
            loss.backward()
            self.spare_opt.step()

        theta_m = self.spare_model.get_params().detach().clone()

        classes_for_this_task = np.unique(dataset.train_loader.dataset.targets)

        for a_class in classes_for_this_task:
            e_t = torch.rand(self.input_shape,
                             requires_grad=True,
                             device=self.device)
            e_t_opt = SGD([e_t], lr=self.args.lr)
            print(file=sys.stderr)
            for _ in range(self.anchor_optimization_steps):
                e_t_opt.zero_grad()

                # Term 1: maximise the loss under the fine-tuned weights theta_m.
                self.spare_opt.zero_grad()
                self.spare_model.set_params(theta_m.detach().clone())
                loss = -torch.sum(
                    self.loss(self.spare_model(e_t.unsqueeze(0)),
                              torch.tensor([a_class]).to(self.device)))
                loss.backward()

                # Term 2: minimise the loss under the current weights theta_t.
                self.spare_opt.zero_grad()
                self.spare_model.set_params(theta_t.detach().clone())
                loss = torch.sum(
                    self.loss(self.spare_model(e_t.unsqueeze(0)),
                              torch.tensor([a_class]).to(self.device)))
                loss.backward()

                # Term 3: keep the anchor's features near the running mean phi.
                self.spare_opt.zero_grad()
                loss = torch.sum(self.gamma *
                                 (self.spare_model.features(e_t.unsqueeze(0)) -
                                  self.phi)**2)
                assert not self.phi.requires_grad
                loss.backward()

                # e_t.grad now holds the sum of the three terms' gradients.
                e_t_opt.step()

            e_t = e_t.detach()
            e_t.requires_grad = False
            self.anchors = torch.cat((self.anchors, e_t.unsqueeze(0)))
            del e_t
            print('Total anchors:', len(self.anchors), file=sys.stderr)

        self.spare_model.zero_grad()

    def observe(self, inputs, labels, not_aug_inputs):
        """One HAL training step; returns the sum of the loss values.

        1. Lazily initialise ``input_shape``, ``anchors`` and ``phi``.
        2. Step on the incoming batch plus a buffer minibatch.
        3. If anchors exist, restore the pre-step weights and step again. The
           gradient is the replay gradient plus that of ``hal_lambda`` times
           the mean squared change of the outputs on the anchors.
        4. Update the running feature mean ``phi`` and store the incoming batch.
        """
        real_batch_size = inputs.shape[0]
        if not hasattr(self, 'input_shape'):
            self.input_shape = inputs.shape[1:]
        if not hasattr(self, 'anchors'):
            # Empty (0, *input_shape) tensor that anchors are concatenated onto.
            self.anchors = torch.zeros(tuple([0] + list(self.input_shape))).to(
                self.device)
        if not hasattr(self, 'phi'):
            print('Building phi', file=sys.stderr)
            with torch.no_grad():
                self.phi = torch.zeros_like(self.net.features(
                    inputs[0].unsqueeze(0)),
                                            requires_grad=False)
            assert not self.phi.requires_grad

        if not self.buffer.is_empty():
            buf_inputs, buf_labels = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            inputs = torch.cat((inputs, buf_inputs))
            labels = torch.cat((labels, buf_labels))

        old_weights = self.net.get_params().detach().clone()

        self.opt.zero_grad()
        outputs = self.net(inputs)

        k = self.task_number

        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()

        first_loss = 0

        # One anchor per class of every completed task.
        assert len(self.anchors) == self.dataset.N_CLASSES_PER_TASK * k

        if len(self.anchors) > 0:
            first_loss = loss.item()
            # Anchor outputs after the temporary step...
            with torch.no_grad():
                pred_anchors = self.net(self.anchors)

            # ...minus the (differentiable) outputs at the restored weights.
            self.net.set_params(old_weights)
            pred_anchors -= self.net(self.anchors)
            loss = self.hal_lambda * (pred_anchors**2).mean()
            # .grad is not zeroed before this backward. The step therefore
            # applies the replay gradient from above plus the anchor-penalty
            # gradient, starting from old_weights.
            loss.backward()
            self.opt.step()

        with torch.no_grad():
            self.phi = self.beta * self.phi + (
                1 - self.beta) * self.net.features(
                    inputs[:real_batch_size]).mean(0)

        self.buffer.add_data(examples=not_aug_inputs,
                             labels=labels[:real_batch_size])

        return first_loss + loss.item()
示例#20
0
class DDPG(object):
    """DDPG Algorithm.

    Holds an actor and a critic, target copies of each (``actor_bar`` /
    ``critic_bar``) that follow them by soft updates, a replay buffer, and one
    Adam optimiser per trained network.
    """
    def __init__(self, obs_size, act_size):
        """Build networks, target networks, buffer and optimisers.

        :param obs_size: observation size, forwarded to Actor and Critic
        :param act_size: action size, forwarded to Actor and Critic
        """
        # Constant hyper-parameters
        self.lr_actor = 1e-4
        self.lr_critic = 1e-3
        self.w_decay = 1e-2  # L2 weight decay for Q
        self.to = 1e-3  # Soft target update rate
        # Replay capacity. This used to be 1e-6 (a tiny float, almost certainly
        # a typo for 1e6), which is not a usable buffer size.
        self.buffer_size = int(1e6)
        self.minibatch_size = 64
        self.mean = 0  # mean of the Gaussian exploration noise
        self.sigma = 1  # std of the Gaussian exploration noise
        self.gemma = 0.99  # discount factor (attribute name kept for compatibility)

        # Initializing networks
        self.actor = Actor(obs_size, act_size)
        self.actor_bar = Actor(obs_size, act_size)

        self.critic = Critic(obs_size, act_size)
        self.critic_bar = Critic(obs_size, act_size)

        # Make actor_bar and critic_bar start with the same weights
        hard_update(self.actor_bar, self.actor)
        hard_update(self.critic_bar, self.critic)

        # Initializing buffer
        self.buffer = Buffer(self.buffer_size)

        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=self.lr_critic,
                                                 weight_decay=self.w_decay)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=self.lr_actor)
        self.mse_loss = nn.MSELoss()

    def exploration_policy(self, shape):
        """Gaussian noise N(self.mean, self.sigma) of size ``(shape[0], shape[1])``."""
        return torch.normal(self.mean, self.sigma, size=(shape[0], shape[1]))

    def action_taking(self, is_explore, state):
        """Select the actions: actor output, plus Gaussian noise when ``is_explore``."""
        a = self.actor.forward(state)
        if is_explore:
            return a + self.exploration_policy(a.shape)
        return a

    def converter(self, arr, to):
        """Convert between numpy and torch.

        With ``to == "Torch"`` a numpy array becomes a tensor
        (``torch.from_numpy``). Otherwise a tensor is detached and returned as
        a numpy array.
        """
        if to == "Torch":
            return torch.from_numpy(arr)
        return arr.detach().numpy()

    def store_data(self, data):
        """Store data on the buffer."""
        self.buffer.store(data)

    def train_algorithm(self):
        """Train the algorithm: one critic/actor/target update per sampled transition.

        Sample layout (index 3 is presumably the next state — confirm with the
        code that calls store_data):
            sample[0] = s  : dim (1, obs_size)
            sample[1] = a  : dim (1, act_size)
            sample[2] = r  : dim (1, 1)
            sample[3] = s' : dim (1, obs_size)
        """
        # Set of mini-batch
        data = self.buffer.get(self.minibatch_size)
        for sample in data:
            # TD target from the target networks. It is only used detached, so
            # skip building an autograd graph for it.
            # NOTE(review): no terminal-state mask is applied — confirm the
            # stored transitions never end an episode.
            with torch.no_grad():
                action = self.actor_bar.forward(sample[3])
                y_i = sample[2] + self.gemma * self.critic_bar.forward(
                    sample[3], action)
            q = self.critic.forward(sample[0], sample[1])

            with torch.autograd.set_detect_anomaly(True):
                # Critic update
                self.critic_optimizer.zero_grad()
                output = self.mse_loss(q, y_i.detach())
                output.backward()
                self.critic_optimizer.step()

                # Actor update: maximise Q(s, mu(s)) by minimising -Q. The
                # previous code minimised +Q, driving the policy towards
                # low-value actions.
                self.actor_optimizer.zero_grad()
                act = self.actor.forward(sample[0])
                q_value = self.critic.forward(sample[0], act)
                policy_loss = -q_value.mean()
                policy_loss.backward()
                self.actor_optimizer.step()

            # Soft-update the target networks: target <- to*online + (1-to)*target
            for target_param, param in zip(self.critic_bar.parameters(),
                                           self.critic.parameters()):
                target_param.data.copy_(self.to * param.data +
                                        (1.0 - self.to) * target_param.data)

            for target_param, param in zip(self.actor_bar.parameters(),
                                           self.actor.parameters()):
                target_param.data.copy_(self.to * param.data +
                                        (1.0 - self.to) * target_param.data)
示例#21
0
class ICarl(ContinualModel):
    """iCaRL-style class-incremental model (``NAME = 'icarl'``).

    Training uses binary cross-entropy with logits. For tasks after the first,
    the targets of previously seen classes are the old network's sigmoid
    outputs. Prediction (``forward``) is nearest-class-mean in feature space,
    with class means computed from the replay buffer. At the end of each task
    the buffer is refilled by a herding selection (``fill_buffer``).
    """
    NAME = 'icarl'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(ICarl, self).__init__(backbone, loss, args, transform)
        self.dataset = get_dataset(args)

        # Instantiate buffers
        self.buffer = Buffer(self.args.buffer_size, self.device)

        # Total class count: N_CLASSES_PER_TASK is either a per-task list or a
        # single value shared by all N_TASKS tasks.
        if type(self.dataset.N_CLASSES_PER_TASK) == list:
            nc = int(np.sum(self.dataset.N_CLASSES_PER_TASK))
        else:
            nc = self.dataset.N_CLASSES_PER_TASK * self.dataset.N_TASKS
        # Identity matrix; get_loss indexes its rows by label to get one-hot targets.
        self.eye = torch.eye(nc).to(self.device)

        # Cached per-class mean features (None = recompute on next forward).
        self.class_means = None
        # Copy of the network taken at the end of the previous task.
        self.old_net = None
        self.current_task = 0

    def forward(self, x):
        """Nearest-class-mean scores: negated squared distance to each class mean.

        Computes (and caches) the class means first if they are not cached.
        """
        if self.class_means is None:
            with torch.no_grad():
                self.compute_class_means()

        feats = self.net.features(x)
        # Insert a class axis so features broadcast against the stacked means;
        # assumes features() returns a (batch, feat) tensor — confirm.
        feats = feats.unsqueeze(1)

        pred = (self.class_means.unsqueeze(0) - feats).pow(2).sum(2)
        return -pred

    def observe(self, inputs, labels, not_aug_inputs, logits=None):
        """One optimisation step on the incoming batch; returns the loss value.

        Also updates the ``classes_so_far`` module buffer and invalidates the
        cached class means. From the second task on, ``logits`` is replaced by
        the old network's sigmoid outputs on ``inputs``.
        """
        labels = labels.long()
        # Running set of labels seen so far, kept on CPU as a registered buffer.
        if not hasattr(self, 'classes_so_far'):
            self.register_buffer('classes_so_far', labels.unique().to('cpu'))
        else:
            self.register_buffer(
                'classes_so_far',
                torch.cat((self.classes_so_far, labels.to('cpu'))).unique())

        self.class_means = None
        if self.current_task > 0:
            with torch.no_grad():
                logits = torch.sigmoid(self.old_net(inputs))
        self.opt.zero_grad()
        loss = self.get_loss(inputs, labels, self.current_task, logits)
        loss.backward()

        self.opt.step()

        return loss.item()

    @staticmethod
    def binary_cross_entropy(pred, y):
        """Mean binary cross-entropy of probabilities ``pred`` against targets ``y``.

        Not called within this class (get_loss uses
        F.binary_cross_entropy_with_logits instead).
        """
        return -(pred.log() * y + (1 - y) * (1 - pred).log()).mean()

    def get_loss(self, inputs: torch.Tensor, labels: torch.Tensor,
                 task_idx: int, logits: torch.Tensor) -> torch.Tensor:
        """
        Computes the loss tensor.

        Outputs are restricted to the classes of tasks ``0..task_idx``. On the
        first task the targets are one-hot labels. Later, the targets for
        earlier classes are ``logits`` (the old network's sigmoid outputs) and
        one-hot labels for the current task's classes.

        :param inputs: the images to be fed to the network
        :param labels: the ground-truth labels
        :param task_idx: the task index
        :param logits: old-network sigmoid outputs (used only when task_idx > 0)
        :return: the differentiable loss value
        """
        # pc = classes of previous tasks, ac = classes up to and including this task
        if type(self.dataset.N_CLASSES_PER_TASK) == list:
            pc = int(np.sum(self.dataset.N_CLASSES_PER_TASK[:task_idx]))
            ac = int(np.sum(self.dataset.N_CLASSES_PER_TASK[:task_idx + 1]))
        else:
            pc = task_idx * self.dataset.N_CLASSES_PER_TASK
            ac = (task_idx + 1) * self.dataset.N_CLASSES_PER_TASK

        outputs = self.net(inputs)[:, :ac]
        if task_idx == 0:
            # Compute loss on the current task
            targets = self.eye[labels][:, :ac]
            loss = F.binary_cross_entropy_with_logits(outputs, targets)
            assert loss >= 0
        else:
            targets = self.eye[labels][:, pc:ac]
            comb_targets = torch.cat((logits[:, :pc], targets), dim=1)
            loss = F.binary_cross_entropy_with_logits(outputs, comb_targets)
            assert loss >= 0

        # Optional L2 penalty on all network parameters.
        if self.args.wd_reg:
            loss += self.args.wd_reg * torch.sum(self.net.get_params()**2)

        return loss

    def begin_task(self, dataset):
        """From the second task on, append the buffer's examples to the training set.

        The buffer labels and (denormalised) examples are concatenated onto
        ``dataset.train_loader.dataset``. For 'seq-core50' they are passed to
        ``add_more_data`` instead.
        """
        denorm = (lambda x: x) if dataset.get_denormalization_transform()\
                                is None else dataset.get_denormalization_transform()
        if self.current_task > 0:
            if self.args.dataset != 'seq-core50':
                dataset.train_loader.dataset.targets = np.concatenate([
                    dataset.train_loader.dataset.targets,
                    self.buffer.labels.cpu().numpy()
                    [:self.buffer.num_seen_examples]
                ])
                if type(dataset.train_loader.dataset.data) == torch.Tensor:
                    dataset.train_loader.dataset.data = torch.cat([
                        dataset.train_loader.dataset.data,
                        torch.stack([
                            denorm(self.buffer.examples[i].type(
                                torch.uint8).cpu())
                            for i in range(self.buffer.num_seen_examples)
                        ]).squeeze(1)
                    ])
                else:
                    # NOTE(review): swapaxes(1, 3) maps (N, C, H, W) to
                    # (N, W, H, C), which also swaps H and W — confirm intended.
                    dataset.train_loader.dataset.data = np.concatenate([
                        dataset.train_loader.dataset.data,
                        torch.stack([
                            (denorm(self.buffer.examples[i] * 255).type(
                                torch.uint8).cpu())
                            for i in range(self.buffer.num_seen_examples)
                        ]).numpy().swapaxes(1, 3)
                    ])
            else:

                # Debug print of the shape of the data being appended.
                print(
                    torch.stack([(denorm(self.buffer.examples[i]).cpu())
                                 for i in range(self.buffer.num_seen_examples)
                                 ]).numpy().shape)
                dataset.train_loader.dataset.add_more_data(
                    more_targets=self.buffer.labels.cpu().numpy()
                    [:self.buffer.num_seen_examples],
                    more_data=torch.stack([
                        (denorm(self.buffer.examples[i]).cpu())
                        for i in range(self.buffer.num_seen_examples)
                    ]).numpy().swapaxes(1, 3))

    def end_task(self, dataset) -> None:
        """Snapshot the network, refill the buffer and advance the task counter."""
        # net.eval() returns the module itself, so the copy is in eval mode.
        self.old_net = deepcopy(self.net.eval())
        self.net.train()
        with torch.no_grad():
            self.fill_buffer(self.buffer, dataset, self.current_task)
        self.current_task += 1
        self.class_means = None

    def compute_class_means(self) -> None:
        """
        Computes a vector representing mean features for each class.

        Uses the buffer examples after the dataset's normalisation transform,
        one mean per entry of ``classes_so_far`` (stacked into
        ``self.class_means``). A class with no buffered example would make
        torch.stack receive an empty list and raise.
        """
        # This function caches class means
        transform = self.dataset.get_normalization_transform()
        class_means = []
        examples, labels, _ = self.buffer.get_all_data(transform)
        for _y in self.classes_so_far:
            x_buf = torch.stack([
                examples[i] for i in range(0, len(examples))
                if labels[i].cpu() == _y
            ]).to(self.device)

            class_means.append(self.net.features(x_buf).mean(0))
        self.class_means = torch.stack(class_means)

    def fill_buffer(self, mem_buffer: Buffer, dataset, t_idx: int) -> None:
        """
        Adds examples from the current task to the memory buffer
        by means of the herding strategy.

        Every class seen so far gets at most ``buffer_size // n_classes``
        slots. Earlier classes keep their first examples. For each current
        class, examples are chosen greedily so that the running mean of the
        chosen features stays closest to the class mean feature.

        :param mem_buffer: the memory buffer
        :param dataset: the dataset from which take the examples
        :param t_idx: the task index
        """

        mode = self.net.training
        self.net.eval()
        samples_per_class = mem_buffer.buffer_size // len(self.classes_so_far)

        if t_idx > 0:
            # 1) First, subsample prior classes
            # NOTE(review): reads self.buffer but writes mem_buffer; these are
            # the same object in end_task's call.
            buf_x, buf_y, buf_l = self.buffer.get_all_data()

            mem_buffer.empty()
            for _y in buf_y.unique():
                idx = (buf_y == _y)
                _y_x, _y_y, _y_l = buf_x[idx], buf_y[idx], buf_l[idx]
                mem_buffer.add_data(examples=_y_x[:samples_per_class],
                                    labels=_y_y[:samples_per_class],
                                    logits=_y_l[:samples_per_class])

        # 2) Then, fill with current tasks
        loader = dataset.not_aug_dataloader(self.args, self.args.batch_size)

        # 2.1 Extract all features (plus non-normalised inputs, labels and
        # sigmoid outputs), gathered on CPU
        a_x, a_y, a_f, a_l = [], [], [], []
        for x, y, not_norm_x in loader:
            x, y, not_norm_x = (a.to(self.device) for a in [x, y, not_norm_x])
            a_x.append(not_norm_x.to('cpu'))
            a_y.append(y.to('cpu'))

            feats = self.net.features(x)
            a_f.append(feats.cpu())
            a_l.append(torch.sigmoid(self.net.classifier(feats)).cpu())
        a_x, a_y, a_f, a_l = torch.cat(a_x), torch.cat(a_y), torch.cat(
            a_f), torch.cat(a_l)

        # 2.2 Compute class means
        for _y in a_y.unique():
            idx = (a_y == _y)
            _x, _y, _l = a_x[idx], a_y[idx], a_l[idx]
            feats = a_f[idx]
            mean_feat = feats.mean(0, keepdim=True)

            running_sum = torch.zeros_like(mean_feat)
            i = 0
            while i < samples_per_class and i < feats.shape[0]:
                # Distance between the class mean and the mean of the i already
                # chosen features plus each candidate.
                cost = (mean_feat - (feats + running_sum) / (i + 1)).norm(2, 1)

                idx_min = cost.argmin().item()

                mem_buffer.add_data(
                    examples=_x[idx_min:idx_min + 1].to(self.device),
                    labels=_y[idx_min:idx_min + 1].to(self.device),
                    logits=_l[idx_min:idx_min + 1].to(self.device))

                running_sum += feats[idx_min:idx_min + 1]
                # Push the chosen feature far away so it is not picked again.
                feats[idx_min] = feats[idx_min] + 1e6
                i += 1

        assert len(mem_buffer.examples) <= mem_buffer.buffer_size

        self.net.train(mode)
示例#22
0
def _train_one_epoch(model, train_loader, buffer, loss, optimizer, device,
                     replay_size, fill_buffer):
    """Run one training pass over ``train_loader`` with experience replay.

    Each loader batch is moved to ``device``. When ``buffer`` is not empty, the
    batch is extended with ``replay_size`` samples returned by
    ``buffer.get_data``. When ``fill_buffer`` is true, the (non-replayed) loader
    batch is offered to ``buffer.add_data`` after the optimizer step.

    :return: ``(mean loss, mean accuracy)`` over the batches of this pass
    :raises statistics.StatisticsError: if ``train_loader`` yields no batch
    """
    epoch_loss = []
    epoch_acc = []
    for x, y in train_loader:
        optimizer.zero_grad()

        batch_x = x.to(device)
        batch_y = y.to(device)
        inputs, labels = batch_x, batch_y
        if not buffer.is_empty():
            # Strategy 50/50: the loader batch is doubled with replayed samples
            # (e.g. 64 from the dataloader + 64 from the buffer).
            buf_input, buf_label = buffer.get_data(replay_size)
            inputs = torch.cat((inputs, torch.stack(buf_input)))
            labels = torch.cat((labels, torch.stack(buf_label)))

        y_pred = model(inputs).squeeze(1)
        s_loss = loss(y_pred, labels)
        acc = binary_accuracy(y_pred, labels)
        # Per-epoch metrics
        epoch_loss.append(s_loss.item())
        epoch_acc.append(acc.item())

        s_loss.backward()
        optimizer.step()

        if fill_buffer:
            buffer.add_data(examples=batch_x, labels=batch_y)
    return statistics.mean(epoch_loss), statistics.mean(epoch_acc)


def train_cl(train_set, test_set, model, loss, optimizer, device, config):
    """Train ``model`` sequentially on each domain of ``train_set`` with replay.

    For every domain the model is trained for ``config['epochs']`` epochs.
    Only during the first epoch of a domain is each incoming mini-batch stored
    in a replay ``Buffer``. Whenever the buffer is non-empty, every mini-batch is
    extended with ``config['batch_size']`` samples drawn from it.

    After each domain the model is evaluated with ``evaluate_past``. For all
    but the last domain, the ``evaluate_next`` result is appended to that
    domain's accuracy row.

    Accuracies and the backward/forward transfer and forgetting metrics are
    printed and written to ``result_<name>.txt``. Mean loss and accuracy per
    epoch are logged to TensorBoard.

    :param train_set: sequence of training datasets, one per domain
    :param test_set: test datasets, one per domain
    :param model: PyTorch model
    :param loss: loss function
    :param optimizer: optimizer
    :param device: device cuda/cpu
    :param config: configuration dict; reads 'buffer_size', 'batch_size', 'epochs'
    """
    name = ""
    num_epochs = config['epochs']
    global_writer = SummaryWriter('./runs/continual/train/global/' + name)
    buffer = Buffer(config['buffer_size'], device)
    accuracy = []

    try:
        # `with` guarantees the results file is closed (and flushed) even if
        # training raises midway.
        with open("result_" + name + ".txt", "w") as text:

            def write_accuracies(header, values):
                # Header line, then one line of space-separated percentages.
                text.write(header + '\n')
                for a in values:
                    text.write(f"{a:.2f}% ")
                text.write('\n')

            # Eval without training; the result is the baseline passed to
            # forward_transfer below.
            random_accuracy = evaluate_past(model, len(test_set) - 1,
                                            test_set, loss, device)
            write_accuracies("Evaluation before training", random_accuracy)

            for index, data_set in enumerate(train_set):
                model.train()
                print(f"----- DOMAIN {index} -----")
                print("Training model...")
                train_loader = DataLoader(data_set,
                                          batch_size=config['batch_size'],
                                          shuffle=False)

                for epoch in tqdm(range(num_epochs)):
                    # Only the first pass over a domain feeds the replay buffer.
                    mean_loss, mean_acc = _train_one_epoch(
                        model, train_loader, buffer, loss, optimizer, device,
                        replay_size=config['batch_size'],
                        fill_buffer=(epoch == 0))

                    global_step = epoch + num_epochs * index
                    global_writer.add_scalar('Train_global/Loss', mean_loss,
                                             global_step)
                    global_writer.add_scalar('Train_global/Accuracy', mean_acc,
                                             global_step)

                    # Report every 100 epochs and on the last epoch. The last
                    # epoch used to be hard-coded as 499, which was only right
                    # for exactly 500 epochs.
                    if epoch % 100 == 0 or epoch == num_epochs - 1:
                        print(
                            f'\nEpoch {epoch:03}/{num_epochs} | Loss: {mean_loss:.5f} '
                            f'| Acc: {mean_acc:.5f}')

                # Test on domain just trained + old domains
                evaluation = evaluate_past(model, index, test_set, loss, device)
                accuracy.append(evaluation)
                write_accuracies(f"Evaluation after domain {index}", evaluation)

                # Also record accuracy on the next, not yet trained, domain
                # (appended to this domain's row after it was written to file).
                if index != len(train_set) - 1:
                    accuracy[index].append(
                        evaluate_next(model, index, test_set, loss, device))

            # Check buffer distribution
            buffer.check_distribution()

            # Compute transfer metrics
            backward = backward_transfer(accuracy)
            forward = forward_transfer(accuracy, random_accuracy)
            forget = forgetting(accuracy)
            # TODO: confirm whether these metrics are expressed in percent.
            print(f'Backward transfer: {backward}')
            print(f'Forward transfer: {forward}')
            print(f'Forgetting: {forget}')

            text.write(f"Backward: {backward}\n")
            text.write(f"Forward: {forward}\n")
            text.write(f"Forgetting: {forget}\n")
    finally:
        # Flush pending TensorBoard events and release the event file.
        global_writer.close()
class KMeansPartitioner(BasicManager):
    """Manager with several agents, one per K-Means cluster of the state space.

    How it works:
    - The global network collects flattened states into a ``Buffer``; local
      networks feed it via ``bootstrap``, taking every
      ``flags.partitioner_granularity``-th state.
    - When the buffer is full, the global network fits a ``KMeans`` model with
      one cluster per agent, syncs every partition network with the training
      network (agent 0) and clears the buffer.
    - Each local network then takes a deep copy of the fitted partitioner.
    - From then on, every ``flags.partitioner_granularity`` steps, the cluster
      of the current state selects the acting agent.
    """

    def set_model_size(self):
        """Use ``flags.partition_count`` agents, with a minimum of 2."""
        self.model_size = max(2, flags.partition_count)  # manager output size
        self.agents_set = set(range(self.model_size))

    def build_agents(self, state_shape, action_shape, concat_size):
        """Create one agent network per partition.

        Only the global network owns the state buffer, the ``KMeans`` model and
        the lock that guards them. It also binds the partition networks to the
        training network.
        """
        # partitioner
        if self.is_global_network():
            self.buffer = Buffer(size=flags.partitioner_dataset_size)
            self.partitioner = KMeans(n_clusters=self.model_size)
        self.partitioner_trained = False
        # agents
        # NOTE(review): the network class is resolved with eval() on a config
        # flag; flags must come from a trusted source.
        network_class = eval(flags.network + "_Network")
        self.model_list = []
        for i in range(self.model_size):
            agent = network_class(
                id="{0}_{1}".format(self.id, i),
                device=self.device,
                session=self.session,
                state_shape=state_shape,
                action_shape=action_shape,
                concat_size=concat_size,
                clip=self.clip[i],
                predict_reward=flags.predict_reward,
                training=self.training)
            self.model_list.append(agent)
        # bind partition nets to training net
        if self.is_global_network():
            self.bind_to_training_net()
            self.lock = threading.Lock()

    def bind_to_training_net(self):
        """Build sync ops from the training network (agent 0) to agents 1..N-1."""
        training_net = self.get_model(0)
        # for syncing local network with global one
        self.sync_list = [
            self.get_model(i).bind_sync(training_net)
            for i in range(1, self.model_size)
        ]

    def sync_with_training_net(self):
        """Run the sync op of every partition network (agents 1..N-1)."""
        for partition_net, sync_op in zip(self.model_list[1:], self.sync_list):
            partition_net.sync(sync_op)

    def get_state_partition(self, state):
        """Return the K-Means cluster of ``state`` and record it in the usage statistics."""
        partition_id = self.partitioner.predict([state.flatten()])[0]
        self.add_to_statistics(partition_id)
        return partition_id

    def query_partitioner(self, step):
        """True when the partitioner is trained and ``step`` is a multiple of the granularity."""
        return self.partitioner_trained and step % flags.partitioner_granularity == 0

    def act(self, act_function, state, concat=None):
        """Re-select the acting agent when due, then act as ``BasicManager.act``."""
        if self.query_partitioner(self.step):
            self.agent_id = self.get_state_partition(state)
        return super().act(act_function, state, concat)

    def populate_partitioner(self, states):
        """Add sampled states to the partitioner dataset; fit it once the buffer is full.

        Only meaningful on the global network, which alone owns ``lock``,
        ``buffer`` and ``partitioner``. Every ``flags.partitioner_granularity``-th
        state is flattened into the buffer.

        As soon as the buffer is full:
        - the partitioner is fitted;
        - the partition networks are synced with the training network;
        - the buffer is cleared;
        - the remaining states of this call are discarded.
        Without that last step they would refill the cleared, never-again-used
        buffer and could trigger a second fit.
        """
        # assert self.is_global_network(), 'only global network can populate partitioner'
        with self.lock:
            if self.partitioner_trained:
                return
            for i in range(0, len(states), flags.partitioner_granularity):
                self.buffer.put(batch=states[i].flatten())
                if self.buffer.is_full():
                    print("Buffer is full, starting partitioner training")
                    self.partitioner.fit(self.buffer.get_batches())
                    print("Partitioner trained")
                    self.partitioner_trained = True
                    print("Syncing with training net")
                    self.sync_with_training_net()
                    print("Cleaning buffer")
                    self.buffer.clean()
                    break

    def bootstrap(self, state, concat=None):
        """Bootstrap as ``BasicManager``; local networks also feed the partitioner.

        Until the global partitioner is trained, a local network sends the
        states of its current agent's batch to the global network. Once the
        global partitioner is trained, it keeps a deep copy of it.
        """
        if self.query_partitioner(self.step):
            self.agent_id = self.get_state_partition(state)
        super().bootstrap(state, concat)
        # populate partitioner training set
        if not self.partitioner_trained and not self.is_global_network():
            # While the partitioner is untrained, all states are associated
            # with the current agent.
            self.global_network.populate_partitioner(
                states=self.batch.states[self.agent_id])
            self.partitioner_trained = self.global_network.partitioner_trained
            if self.partitioner_trained:
                self.partitioner = copy.deepcopy(
                    self.global_network.partitioner)
class BasicManager(object):
	"""Manager owning the agent network(s), their optimizers, replay buffers and statistics.

	A manager built with ``global_network=None`` while training is the global one.
	It creates, per agent:
	- a global step;
	- a learning rate;
	- a clip value;
	- a gradient optimizer.

	Local managers reuse the global clip values. They minimize their loss through
	the global optimizers and keep sync ops towards the matching global agents.

	All per-agent state is indexed by agent id in ``range(self.model_size)``.
	``set_model_size`` and ``build_agents`` decide how many agents exist; this
	base class uses exactly one.
	"""
	
	def __init__(self, session, device, id, action_shape, state_shape, concat_size=0, global_network=None, training=True):
		"""Build the agent(s) and, when training, the optimizer and buffer machinery.

		Args:
			session: session forwarded to each agent network.
			device: device forwarded to each agent network.
			id: prefix of the agent ids (``'<id>_<agent index>'``).
			action_shape, state_shape, concat_size: forwarded to ``build_agents``.
			global_network: global manager to bind to; ``None`` makes this
				manager the global one when training.
			training: when False, nothing is built. ``model_list`` is reused from
				``global_network`` (which therefore must not be None), and
				``self.global_network`` is set to None.
		"""
		self.training = training
		self.session = session
		self.id = id
		self.device = device
		self.state_shape = state_shape
		# Must run first: the optimizer, agents and loss lists below all loop over model_size.
		self.set_model_size()
		if self.training:
			self.global_network = global_network
			# Gradient optimizer and clip range
			if not self.is_global_network():
				# Local managers reuse the (possibly decayed) global clip values.
				self.clip = self.global_network.clip
			else:
				self.initialize_gradient_optimizer()
			# Build agents
			self.model_list = []
			self.build_agents(state_shape=state_shape, action_shape=action_shape, concat_size=concat_size)
			# Build experience buffer
			if flags.replay_ratio > 0:
				if flags.prioritized_replay:
					self.experience_buffer = PrioritizedBuffer(size=flags.replay_buffer_size)
					# self.beta_schedule = LinearSchedule(flags.max_time_step, initial_p=0.4, final_p=1.0)
				else:
					self.experience_buffer = Buffer(size=flags.replay_buffer_size)
			if flags.predict_reward:
				self.reward_prediction_buffer = Buffer(size=flags.reward_prediction_buffer_size)
			# Bind optimizer to global
			if not self.is_global_network():
				self.bind_to_global(self.global_network)
			# Count based exploration
			if flags.use_count_based_exploration_reward:
				self.projection = None
				self.projection_dataset = []
			if flags.print_loss:
				self._loss_list = [{} for _ in range(self.model_size)]
		else:
			# NOTE: with training=False, is_global_network() therefore returns True.
			self.global_network = None
			self.model_list = global_network.model_list
		# Statistics: ids of the most recently used agents (see add_to_statistics).
		self._model_usage_list = deque()
			
	def is_global_network(self):
		"""Return True when this manager has no global network to sync with."""
		return self.global_network is None
			
	def set_model_size(self):
		"""Use a single agent with id 0."""
		self.model_size = 1
		self.agents_set = set([0])
			
	def build_agents(self, state_shape, action_shape, concat_size):
		"""Build the single agent network and append it to ``model_list``.

		The network class is ``<flags.network>_Network``, looked up with eval().
		NOTE(review): eval() on a config flag; flags must come from a trusted source.
		"""
		agent=eval('{}_Network'.format(flags.network))(
			session=self.session, 
			id='{0}_{1}'.format(self.id, 0), 
			device=self.device, 
			state_shape=state_shape, 
			action_shape=action_shape, 
			concat_size=concat_size, 
			clip=self.clip[0], 
			predict_reward=flags.predict_reward, 
			training = self.training
		)
		self.model_list.append(agent)
			
	def sync(self):
		"""Run, for every agent, the sync op built by ``bind_to_global``.

		Only valid on local managers: ``sync_list`` is created in ``bind_to_global``.
		"""
		# assert not self.is_global_network(), 'you are trying to sync the global network with itself'
		for i in range(self.model_size):
			agent = self.model_list[i]
			sync = self.sync_list[i]
			agent.sync(sync)
			
	def initialize_gradient_optimizer(self):
		"""Create the per-agent training schedule and optimizer (global manager only).

		For each agent this builds:
		- a non-trainable global-step variable;
		- a learning rate;
		- a clip value;
		- a ``tf.train.<flags.optimizer>Optimizer``.
		Learning rate and clip are constants (``flags.alpha`` / ``flags.clip``) unless
		their ``*_decay`` flag is set. In that case they are decayed over the agent's
		global step by ``tf.train.<flags.*_annealing_function>``.
		"""
		self.global_step = []
		self.learning_rate = []
		self.clip = []
		self.gradient_optimizer = []
		for i in range(self.model_size):
			# global step
			self.global_step.append( tf.Variable(0, trainable=False) )
			# learning rate
			self.learning_rate.append( eval('tf.train.'+flags.alpha_annealing_function)(learning_rate=flags.alpha, global_step=self.global_step[i], decay_steps=flags.alpha_decay_steps, decay_rate=flags.alpha_decay_rate) if flags.alpha_decay else flags.alpha )
			# clip (annealed with the same kind of decay function as the learning rate)
			self.clip.append( eval('tf.train.'+flags.clip_annealing_function)(learning_rate=flags.clip, global_step=self.global_step[i], decay_steps=flags.clip_decay_steps, decay_rate=flags.clip_decay_rate) if flags.clip_decay else flags.clip )
			# gradient optimizer
			self.gradient_optimizer.append( eval('tf.train.'+flags.optimizer+'Optimizer')(learning_rate=self.learning_rate[i], use_locking=True) )
			
	def bind_to_global(self, global_network):
		"""Connect every local agent to the matching agent of ``global_network``.

		Each local loss is minimized through the global optimizer and global step
		over the global agent's shared variables. A sync op towards the global
		agent is stored in ``sync_list``.
		"""
		self.sync_list = []
		for i in range(self.model_size):
			local_agent = self.get_model(i)
			global_agent = global_network.get_model(i)
			local_agent.minimize_local_loss(optimizer=global_network.gradient_optimizer[i], global_step=global_network.global_step[i], global_var_list=global_agent.get_shared_keys())
			self.sync_list.append(local_agent.bind_sync(global_agent)) # for syncing local network with global one

	def get_model(self, id):
		"""Return the agent network with index ``id``."""
		return self.model_list[id]
		
	def get_statistics(self):
		"""Return a dict of running statistics.

		The dict contains:
		- ``loss_<name><agent>_avg``: the mean of each recorded loss term, only
		  when training with ``flags.print_loss``;
		- ``model_<agent>``: the fraction of recent usage records naming each
		  agent, only when there is more than one agent.
		"""
		stats = {}
		if self.training:
			# build loss statistics
			if flags.print_loss:
				for i in range(self.model_size):
					for key, value in self._loss_list[i].items():
						stats['loss_{}{}_avg'.format(key,i)] = np.average(value)
		# build models usage statistics
		if self.model_size > 1:
			total_usage = 0
			usage_matrix = {}
			for u in self._model_usage_list:
				if not (u in usage_matrix):
					usage_matrix[u] = 0
				usage_matrix[u] += 1
				total_usage += 1
			# Every agent gets an entry, even if never used.
			for i in range(self.model_size):
				stats['model_{}'.format(i)] = 0
			for key, value in usage_matrix.items():
				stats['model_{}'.format(key)] = value/total_usage if total_usage != 0 else 0
		return stats
		
	def add_to_statistics(self, id):
		"""Record a use of agent ``id``, keeping the last ``flags.match_count_for_evaluation`` records."""
		self._model_usage_list.append(id)
		if len(self._model_usage_list) > flags.match_count_for_evaluation:
			self._model_usage_list.popleft() # remove old statistics
		
	def get_shared_keys(self):
		"""Return the concatenation of every agent's ``get_shared_keys()``."""
		vars = []
		for agent in self.model_list:
			vars += agent.get_shared_keys()
		return vars
		
	def reset(self):
		"""Reset per-episode state.

		This resets:
		- the step counter;
		- the active agent (back to 0);
		- the recurrent internal states (one shared slot, or one slot per agent);
		- the state-hash visit counts, when training with count-based exploration.
		"""
		self.step = 0
		self.agent_id = 0
		# Internal states
		self.internal_states = None if flags.share_internal_state else [None]*self.model_size
		if self.training:
			# Count based exploration
			if flags.use_count_based_exploration_reward:
				self.hash_state_table = {}
			
	def initialize_new_batch(self):
		"""Start a fresh ``ExperienceBatch`` sized by ``model_size``."""
		self.batch = ExperienceBatch(self.model_size)
		
	def estimate_value(self, agent_id, states, concats=None, internal_state=None):
		"""Return ``predict_value`` of agent ``agent_id`` on ``states``."""
		return self.get_model(agent_id).predict_value(states=states, concats=concats, internal_state=internal_state)
		
	def act(self, act_function, state, concat=None):
		"""Take one environment step with the currently active agent.

		Args:
			act_function: callable taking the chosen action and returning
				``(new_state, extrinsic_reward, terminal)``.
			state: current observation.
			concat: optional extra input fed to the network alongside ``state``.

		Returns:
			``(new_state, value, action, total_reward, terminal, policy)``, where
			``total_reward`` is a float32 array ``[extrinsic, intrinsic]``.

		When training, the step is also added to ``self.batch``. The extrinsic
		reward is optionally clipped, and the intrinsic reward is the count-based
		exploration bonus when enabled.
		"""
		agent_id = self.agent_id
		agent = self.get_model(agent_id)
		
		internal_state = self.internal_states if flags.share_internal_state else self.internal_states[agent_id]
		action_batch, value_batch, policy_batch, new_internal_state = agent.predict_action(states=[state], concats=[concat], internal_state=internal_state)
		if flags.share_internal_state:
			self.internal_states = new_internal_state
		else:
			self.internal_states[agent_id] = new_internal_state
		
		# Predictions were made on a batch of one state.
		action, value, policy = action_batch[0], value_batch[0], policy_batch[0]
		new_state, extrinsic_reward, terminal = act_function(action)
		if self.training:
			if flags.clip_reward:
				extrinsic_reward = np.clip(extrinsic_reward, flags.min_reward, flags.max_reward)
			
		intrinsic_reward = 0
		if self.training:
			if flags.use_count_based_exploration_reward: # intrinsic reward
				intrinsic_reward += self.get_count_based_exploration_reward(new_state)

		total_reward = np.array([extrinsic_reward, intrinsic_reward], dtype=np.float32)
		if self.training:
			# Store the internal state that was used to predict this step.
			self.batch.add_action(agent_id=agent_id, state=state, concat=concat, action=action, policy=policy, reward=total_reward, value=value, internal_state=internal_state)
		# update step at the end of the action
		self.step += 1
		# return result
		return new_state, value, action, total_reward, terminal, policy

	def get_count_based_exploration_reward(self, new_state):
		"""Return a count-based exploration bonus for ``new_state``.

		Flattened states are collected until ``flags.projection_dataset_size`` of
		them exist. A ``SparseRandomProjection`` is then (re)fitted on them and
		the collection restarts, so the projection is refitted on every new full
		dataset.

		Once a projection exists, the sign pattern of the projected state is used
		as a hash. With ``n`` visits to that hash, the bonus is
		``2/sqrt(n) - 1`` (in [-1, 1]). It is scaled by
		``flags.positive_exploration_coefficient`` when positive and by
		``flags.negative_exploration_coefficient`` otherwise.

		Returns 0 while no projection exists yet.
		"""
		if len(self.projection_dataset) < flags.projection_dataset_size:
			self.projection_dataset.append(new_state.flatten())
		if len(self.projection_dataset) == flags.projection_dataset_size:
			if self.projection is None:
				self.projection = SparseRandomProjection(n_components=flags.exploration_hash_size if flags.exploration_hash_size > 0 else 'auto') # http://scikit-learn.org/stable/modules/random_projection.html
			self.projection.fit(self.projection_dataset)
			self.projection_dataset = [] # reset
		if self.projection is not None:
			state_projection = self.projection.transform([new_state.flatten()])[0] # project to smaller dimension
			state_hash = ''.join('1' if x > 0 else '0' for x in state_projection) # build binary locality-sensitive hash
			if state_hash not in self.hash_state_table:
				self.hash_state_table[state_hash] = 1
			else:
				self.hash_state_table[state_hash] += 1
			exploration_bonus = 2/np.sqrt(self.hash_state_table[state_hash]) - 1 # in [-1,1]
			return flags.positive_exploration_coefficient*exploration_bonus if exploration_bonus > 0 else flags.negative_exploration_coefficient*exploration_bonus
		return 0
			
	def compute_discounted_cumulative_reward(self, batch):
		"""Compute the batch's discounted cumulative rewards in place and return ``batch``.

		Bootstraps from ``batch.bootstrap['value']`` when present, otherwise from
		0. Uses ``flags.gamma`` and ``flags.lambd``.
		"""
		last_value = batch.bootstrap['value'] if 'value' in batch.bootstrap else 0.
		batch.compute_discounted_cumulative_reward(agents=self.agents_set, last_value=last_value, gamma=flags.gamma, lambd=flags.lambd)
		return batch
		
	def train(self, batch):
		"""Train every agent that has at least one step in ``batch``.

		For agents with reward prediction enabled, an auxiliary batch is sampled
		from ``reward_prediction_buffer``. Only the first internal state of each
		agent's slice is passed to the model. Loss terms are recorded for
		``get_statistics`` when ``flags.print_loss`` is set.

		Returns:
			List of the errors returned by ``model.train``, one per trained agent.
		"""
		# assert self.global_network is not None, 'Cannot train the global network.'
		states = batch.states
		internal_states = batch.internal_states
		concats = batch.concats
		actions = batch.actions
		policies = batch.policies
		values = batch.values
		rewards = batch.rewards
		dcr = batch.discounted_cumulative_rewards
		gae = batch.generalized_advantage_estimators
		batch_error = []
		for i in range(self.model_size):
			batch_size = len(states[i])
			if batch_size > 0:
				model = self.get_model(i)
				# reward prediction
				if model.predict_reward:
					sampled_batch = self.reward_prediction_buffer.sample()
					reward_prediction_states, reward_prediction_target = self.get_reward_prediction_tuple(sampled_batch)
				else:
					reward_prediction_states = None
					reward_prediction_target = None
				# train
				error, train_info = model.train(
					states=states[i], concats=concats[i],
					actions=actions[i], values=values[i],
					policies=policies[i],
					rewards=rewards[i],
					discounted_cumulative_rewards=dcr[i],
					generalized_advantage_estimators=gae[i],
					reward_prediction_states=reward_prediction_states,
					reward_prediction_target=reward_prediction_target,
					internal_state=internal_states[i][0]
				)
				batch_error.append(error)
				# loss statistics
				if flags.print_loss:
					for key, value in train_info.items():
						if key not in self._loss_list[i]:
							self._loss_list[i][key] = deque()
						self._loss_list[i][key].append(value)
						if len(self._loss_list[i][key]) > flags.match_count_for_evaluation: # remove old statistics
							self._loss_list[i][key].popleft()
		return batch_error
		
	def bootstrap(self, state, concat=None):
		"""Store in ``self.batch.bootstrap`` what is needed to bootstrap its returns.

		The stored entries are the active agent, its internal state, ``state``,
		``concat`` and the agent's value estimate of ``state``.
		"""
		agent_id = self.agent_id
		internal_state = self.internal_states if flags.share_internal_state else self.internal_states[agent_id]
		value_batch, _ = self.estimate_value(agent_id=agent_id, states=[state], concats=[concat], internal_state=internal_state)
		bootstrap = self.batch.bootstrap
		bootstrap['internal_state'] = internal_state
		bootstrap['agent_id'] = agent_id
		bootstrap['state'] = state
		bootstrap['concat'] = concat
		bootstrap['value'] = value_batch[0]
		
	def replay_value(self, batch): # replay values
		"""Recompute a replayed batch's values with the current networks.

		The per-step values and, if present, the bootstrap value are replaced.
		The discounted cumulative rewards are then recomputed, and the batch is
		returned.
		"""
		# replay values
		for (agent_id,pos) in batch.step_generator():
			concat, state, internal_state = batch.get_action(['concats','states','internal_states'], agent_id, pos)
			value_batch, _ = self.estimate_value(agent_id=agent_id, states=[state], concats=[concat], internal_state=internal_state)
			batch.set_action({'values':value_batch[0]}, agent_id, pos)
		if 'value' in batch.bootstrap:
			bootstrap = batch.bootstrap
			agent_id = bootstrap['agent_id']
			value_batch, _ = self.estimate_value(agent_id=agent_id, states=[bootstrap['state']], concats=[bootstrap['concat']], internal_state=bootstrap['internal_state'])
			bootstrap['value'] = value_batch[0]
		return self.compute_discounted_cumulative_reward(batch)
		
	def add_to_reward_prediction_buffer(self, batch):
		"""Store ``batch`` for reward prediction if it has at least 2 steps.

		The ``type_id`` is 1 when its cumulative extrinsic reward is non-zero,
		0 otherwise.
		"""
		batch_size = batch.get_size(self.agents_set)
		# get_reward_prediction_tuple needs at least one state plus one following step.
		if batch_size < 2:
			return
		batch_extrinsic_reward = batch.get_cumulative_reward(self.agents_set)[0]
		self.reward_prediction_buffer.put(batch=batch, type_id=1 if batch_extrinsic_reward != 0 else 0) # process batch only after sampling, for better performance
			
	def get_reward_prediction_tuple(self, batch):
		"""Build one reward-prediction sample from ``batch``.

		Returns:
			``(states, target)``, where:
			- ``states`` are up to 3 consecutive states starting at a random
			  position;
			- ``target`` is a (1, 3) one-hot array [zero, positive, negative]
			  for the sign of the extrinsic reward stored at the step right
			  after those states.
		"""
		flat_states = [batch.get_action('states', agent_id, pos) for (agent_id,pos) in batch.step_generator(self.agents_set)]
		flat_rewards = [batch.get_action('rewards', agent_id, pos) for (agent_id,pos) in batch.step_generator(self.agents_set)]
		states_count = len(flat_states)
		# Leave room for the target step after the sampled states.
		length = min(3, states_count-1)
		start_idx = np.random.randint(states_count-length) if states_count > length else 0
		reward_prediction_states = [flat_states[start_idx+i] for i in range(length)]
		reward_prediction_target = np.zeros((1,3))
		
		target_reward = flat_rewards[start_idx+length][0] # use only extrinsic rewards
		if target_reward == 0:
			reward_prediction_target[0][0] = 1.0 # zero
		elif target_reward > 0:
			reward_prediction_target[0][1] = 1.0 # positive
		else:
			reward_prediction_target[0][2] = 1.0 # negative
		return reward_prediction_states, reward_prediction_target
			
	def add_to_replay_buffer(self, batch, batch_error):
		"""Store ``batch`` in the experience replay buffer.

		Skipped cases:
		- empty batches;
		- batches with zero total reward, when
		  ``flags.save_only_batches_with_reward`` is set.

		The ``type_id`` is 1 if the intrinsic reward is positive, else 2 if the
		extrinsic reward is positive, else 0. A prioritized buffer uses the total
		reward as priority.

		``batch_error`` is currently unused.
		"""
		batch_size = batch.get_size(self.agents_set)
		if batch_size < 1:
			return
		batch_reward = batch.get_cumulative_reward(self.agents_set)
		batch_extrinsic_reward = batch_reward[0]
		batch_intrinsic_reward = batch_reward[1]
		batch_tot_reward = batch_extrinsic_reward + batch_intrinsic_reward
		if batch_tot_reward == 0 and flags.save_only_batches_with_reward:
			return
		if flags.replay_using_default_internal_state:
			batch.reset_internal_states()
		type_id = (1 if batch_intrinsic_reward > 0 else (2 if batch_extrinsic_reward > 0 else 0))
		if flags.prioritized_replay:
			self.experience_buffer.put(batch=batch, priority=batch_tot_reward, type_id=type_id)
		else:
			self.experience_buffer.put(batch=batch, type_id=type_id)
		
	def replay_experience(self):
		"""Train on a Poisson(``flags.replay_ratio``) number of sampled old batches.

		Does nothing until the buffer holds at least ``flags.replay_start``
		batches. When ``flags.replay_value`` is set, values are recomputed before
		training.
		"""
		if not self.experience_buffer.has_atleast(flags.replay_start):
			return
		n = np.random.poisson(flags.replay_ratio)
		for _ in range(n):
			old_batch = self.experience_buffer.sample()
			self.train(self.replay_value(old_batch) if flags.replay_value else old_batch)
		
	def process_batch(self, global_step):
		"""Compute returns for ``self.batch`` and train on it.

		Once ``global_step`` exceeds ``flags.replay_step``, it then replays old
		experience and stores this batch for future replay.
		"""
		batch = self.compute_discounted_cumulative_reward(self.batch)
		# reward prediction
		if flags.predict_reward:
			self.add_to_reward_prediction_buffer(batch) # do it before training, this way there will be at least one batch in the reward_prediction_buffer
			if self.reward_prediction_buffer.is_empty():
				return # cannot train without reward prediction, wait until reward_prediction_buffer is not empty
		# train
		batch_error = self.train(batch)
		# experience replay (after training!)
		if flags.replay_ratio > 0 and global_step > flags.replay_step:
			self.replay_experience()
			self.add_to_replay_buffer(batch, batch_error)
class NetworkManager(object):
	# Experience replay
	if flags.replay_mean > 0:
		# Use locking because buffers are shared among threads
		experience_buffer_lock = Lock()
		if flags.prioritized_replay:
			experience_buffer = PrioritizedBuffer(
				size=flags.replay_buffer_size, 
				alpha=flags.prioritized_replay_alpha, 
				prioritized_drop_probability=flags.prioritized_drop_probability
			)
		else:
			experience_buffer = Buffer(size=flags.replay_buffer_size)
		ImportantInformation(experience_buffer, 'experience_buffer')
	
	def __init__(self, group_id, environment_info, global_network=None, training=True):
		self.training = training
		self.group_id = group_id
		self.set_model_size()
		self.global_network = global_network
		# Build agents
		self.algorithm = eval('{}_Algorithm'.format(flags.algorithm))
		self.model_list = self.build_agents(algorithm=self.algorithm, environment_info=environment_info)
		# Build global_step and gradient_optimizer
		if self.is_global_network():
			self.global_step, self.gradient_optimizer = self.build_gradient_optimizer()
		else:
			self.global_step = self.global_network.global_step
		# Prepare loss
		self.prepare_loss(self.global_step)
		if self.training:
			if not self.is_global_network():
				self.minimize_local_loss(self.global_network)
		# Bind optimizer to global
		if not self.is_global_network():
			self.bind_to_global(self.global_network)
		# Intrinsic reward
		if flags.intrinsic_reward and flags.scale_intrinsic_reward:
			self.intrinsic_reward_scaler = [RunningMeanStd() for _ in range(self.model_size)]
			ImportantInformation(self.intrinsic_reward_scaler, 'intrinsic_reward_scaler{}'.format(self.group_id))
		# Reward manipulators
		self.intrinsic_reward_manipulator = eval(flags.intrinsic_reward_manipulator)
		self.intrinsic_reward_mini_batch_size = int(flags.batch_size*flags.intrinsic_rewards_mini_batch_fraction)
		if flags.intrinsic_reward:
			print('[Group{}] Intrinsic rewards mini-batch size: {}'.format(self.group_id, self.intrinsic_reward_mini_batch_size))

	def get_statistics(self):
		stats = {}
		for model in self.model_list:
			stats.update(model.get_statistics())
		return stats
			
	def is_global_network(self):
		return self.global_network is None
			
	def set_model_size(self):
		self.model_size = 1
		self.agents_set = (0,)
			
	def build_agents(self, algorithm, environment_info):
		model_list = []
		agent=algorithm(
			group_id=self.group_id,
			model_id=0,
			environment_info=environment_info, 
			training=self.training
		)
		model_list.append(agent)
		return model_list
		
	def sync(self):
		assert not self.is_global_network(), 'Trying to sync the global network with itself'
		# Synchronize models
		for model in self.model_list:
			model.sync()
					
	def build_gradient_optimizer(self):
		global_step = [None]*self.model_size
		gradient_optimizer = [None]*self.model_size
		for i in range(self.model_size):
			gradient_optimizer[i], global_step[i] = self.model_list[i].build_optimizer(flags.optimizer)
		return global_step, gradient_optimizer
	
	def minimize_local_loss(self, global_network):
		for i,(local_agent,global_agent) in enumerate(zip(self.model_list, global_network.model_list)):
			local_agent.minimize_local_loss(optimizer=global_network.gradient_optimizer[i], global_step=global_network.global_step[i], global_agent=global_agent)
			
	def prepare_loss(self, global_step):
		for i,model in enumerate(self.model_list):
			model.prepare_loss(global_step[i])
			
	def bind_to_global(self, global_network):
		# for synching local network with global one
		for local_agent,global_agent in zip(self.model_list, global_network.model_list):
			local_agent.bind_sync(global_agent)

	def get_model(self, id=0):
		return self.model_list[id]
		
	def predict_action(self, states, internal_states):
		action_dict = {
			'states': states,
			'internal_states': internal_states,
			'sizes': [1 for _ in range(len(states))] # states are from different environments with different internal states
		}
		actions, hot_actions, policies, values, new_internal_states = self.get_model().predict_action(action_dict)
		agents = [0]*len(actions)
		return actions, hot_actions, policies, values, new_internal_states, agents
	
	def _play_critic(self, batch, with_value=True, with_bootstrap=True, with_intrinsic_reward=True):
		# Compute values and bootstrap
		if with_value:
			for agent_id in range(self.model_size):
				value_dict = {
					'states': batch.states[agent_id],
					'actions': batch.actions[agent_id],
					'policies': batch.policies[agent_id],
					'internal_states': [ batch.internal_states[agent_id][0] ], # a single internal state
					'bootstrap': [ {'state':batch.new_states[agent_id][-1]} ],
					'sizes': [ len(batch.states[agent_id]) ] # playing critic on one single batch
				}
				value_batch, bootstrap_value, extra_batch = self.get_model(agent_id).predict_value(value_dict)
				if extra_batch is not None:
					batch.extras[agent_id] = list(extra_batch)
				batch.values[agent_id] = list(value_batch)
				batch.bootstrap[agent_id] = bootstrap_value
				assert len(batch.states[agent_id]) == len(batch.values[agent_id]), "Number of values does not match the number of states"
		elif with_bootstrap:
			self._bootstrap(batch)
		if with_intrinsic_reward:
			self._compute_intrinsic_rewards(batch)
		if with_value or with_intrinsic_reward or with_bootstrap:
			self._compute_discounted_cumulative_reward(batch)
			
	def _bootstrap(self, batch):
		for agent_id in range(self.model_size):
			_, _, _, (bootstrap_value,), _, _ = self.predict_action(
				states=batch.new_states[agent_id][-1:],  
				internal_states=batch.new_internal_states[agent_id][-1:]
			)
			batch.bootstrap[agent_id] = bootstrap_value

	def _compute_intrinsic_rewards(self, batch):
		for agent_id in range(self.model_size):
			# Get actual rewards
			rewards = batch.rewards[agent_id]
			manipulated_rewards = batch.manipulated_rewards[agent_id]
			# Predict intrinsic rewards
			reward_dict = {
				'states': batch.new_states[agent_id],
				'state_mean': self.state_mean,
				'state_std':self.state_std
			}
			intrinsic_rewards = self.get_model(agent_id).predict_reward(reward_dict)
			# Scale intrinsic rewards
			if flags.scale_intrinsic_reward:
				scaler = self.intrinsic_reward_scaler[agent_id]
				# Build intrinsic_reward scaler
				scaler.update(intrinsic_rewards)
				# If the reward scaler is initialized, we can compute the intrinsic reward
				if not scaler.initialized:
					continue
			# Add intrinsic rewards to batch
			if self.intrinsic_reward_mini_batch_size > 1: 
				# Keep only best intrinsic rewards
				for i in range(0, len(intrinsic_rewards), self.intrinsic_reward_mini_batch_size):
					best_intrinsic_reward_index = i+np.argmax(intrinsic_rewards[i:i+self.intrinsic_reward_mini_batch_size])
					best_intrinsic_reward = intrinsic_rewards[best_intrinsic_reward_index]
					# print(i, best_intrinsic_reward_index, best_intrinsic_reward)
					if flags.scale_intrinsic_reward:
						best_intrinsic_reward = best_intrinsic_reward/scaler.std
					rewards[best_intrinsic_reward_index][1] = best_intrinsic_reward
					manipulated_rewards[best_intrinsic_reward_index][1] = self.intrinsic_reward_manipulator(best_intrinsic_reward)
				# print(best_intrinsic_reward_index,best_intrinsic_reward)
			else: 
				# Keep all intrinsic rewards
				if flags.scale_intrinsic_reward:
					intrinsic_rewards = intrinsic_rewards/scaler.std
				manipulated_intrinsic_rewards = self.intrinsic_reward_manipulator(intrinsic_rewards)
				for i in range(len(intrinsic_rewards)):
					rewards[i][1] = intrinsic_rewards[i]
					manipulated_rewards[i][1] = manipulated_intrinsic_rewards[i]				
					
	def _compute_discounted_cumulative_reward(self, batch):
		batch.compute_discounted_cumulative_reward(
			agents=self.agents_set, 
			gamma=flags.gamma, 
			cumulative_return_builder=self.algorithm.get_reversed_cumulative_return
		)
		
	def _train(self, batch, replay=False, start=None, end=None):
		"""Hand every local model its (optionally sliced) part of ``batch`` for training.

		Model ``i`` receives ``batch.<field>[i]`` for states, actions, action masks,
		values, policies and cumulative returns. It also receives advantages unless
		``flags.runtime_advantage`` is set, and one internal state.

		:param batch: the experience batch; models whose ``states`` are empty are skipped.
		:param replay: forwarded unchanged to ``model.prepare_train``.
		:param start: optional slice start. It is honoured for a model only when it is
			not None, not 0 and greater than minus that model's batch size. When it is
			honoured, the model's internal state is taken at ``start`` instead of 0.
		:param end: optional slice end. It is honoured for a model only when it is not
			None, not 0 and smaller than that model's batch size.
		"""
		assert self.global_network is not None, 'Cannot directly _train the global network.'

		def _select(sequence, bounds):
			# ``bounds`` is None when no slicing is needed: pass the original object through
			return sequence if bounds is None else sequence[bounds]

		# Train every model
		for i, model in enumerate(self.model_list):
			batch_size = len(batch.states[i])
			# Ignore empty batches
			if batch_size == 0:
				continue
			# Validate the bounds against THIS model's batch size. Per-model locals are used,
			# so rejecting a bound for one model does not also discard it for the next ones.
			is_valid_start = start is not None and start != 0 and start > -batch_size
			is_valid_end = end is not None and end != 0 and end < batch_size
			model_start = start if is_valid_start else None
			model_end = end if is_valid_end else None
			bounds = slice(model_start, model_end) if (is_valid_start or is_valid_end) else None
			# Build _train dictionary
			train_dict = {
				'states': _select(batch.states[i], bounds),
				'actions': _select(batch.actions[i], bounds),
				'action_masks': _select(batch.action_masks[i], bounds),
				'values': _select(batch.values[i], bounds),
				'policies': _select(batch.policies[i], bounds),
				'cumulative_returns': _select(batch.cumulative_returns[i], bounds),
				'internal_state': batch.internal_states[i][model_start] if is_valid_start else batch.internal_states[i][0],
				'state_mean': self.state_mean,
				'state_std': self.state_std,
			}
			if not flags.runtime_advantage:
				train_dict['advantages'] = _select(batch.advantages[i], bounds)
			# Prepare _train
			model.prepare_train(train_dict=train_dict, replay=replay)
		
	def _add_to_replay_buffer(self, batch, is_best):
		"""Store ``batch`` in the experience buffer under a two-character type id.

		The first character of the type id is '1' when the batch's cumulative extrinsic
		reward is positive, else '0'. The second is '1' when ``is_best`` is truthy, else '0'.
		With ``flags.prioritized_replay`` the batch is stored with a priority. The priority
		is its cumulative intrinsic reward when ``flags.intrinsic_reward`` is set, otherwise
		its cumulative extrinsic reward.

		:returns: False (nothing stored) when ``batch.is_empty(self.agents_set)``, else True.
		"""
		if batch.is_empty(self.agents_set):
			return False
		extrinsic, intrinsic = batch.get_cumulative_reward(self.agents_set)
		type_id = ('1' if extrinsic > 0 else '0') + ('1' if is_best else '0')
		put_kwargs = dict(batch=batch, type_id=type_id)
		if flags.prioritized_replay:
			put_kwargs['priority'] = intrinsic if flags.intrinsic_reward else extrinsic
		with self.experience_buffer_lock:
			self.experience_buffer.put(**put_kwargs)
		return True

	def try_to_replay_experience(self):
		"""Replay a Poisson(``flags.replay_mean``) number of batches sampled from the buffer.

		Does nothing when ``flags.replay_mean <= 0`` or when the experience buffer holds
		fewer than ``flags.replay_start`` elements. Each sampled batch goes through the
		critic (values are recomputed only if ``flags.recompute_value_when_replaying``)
		and is then trained on with ``replay=True``. When both ``flags.prioritized_replay``
		and ``flags.intrinsic_reward`` are set, every replayed batch's priority is refreshed
		afterwards with its recomputed cumulative intrinsic reward.
		"""
		if flags.replay_mean <= 0:
			return
		# Check whether experience buffer has enough elements for replaying
		if not self.experience_buffer.has_atleast(flags.replay_start):
			return
		prioritized_replay_with_update = flags.prioritized_replay and flags.intrinsic_reward
		# (batch, key, type_id) triples whose priority must be refreshed; stays empty otherwise
		batch_to_update = []
		# Sample n batches from experience buffer
		n = np.random.poisson(flags.replay_mean)
		for _ in range(n):
			# Sample batch
			if prioritized_replay_with_update:
				with self.experience_buffer_lock:
					keyed_sample = self.experience_buffer.keyed_sample()
				batch_to_update.append(keyed_sample)
				old_batch, _, _ = keyed_sample
			else:
				with self.experience_buffer_lock:
					old_batch = self.experience_buffer.sample()
			# NOTE: the lock is released before replaying, so the buffer may change meanwhile
			# and the later priority update may no longer be consistent with it.
			self._play_critic(batch=old_batch, with_value=flags.recompute_value_when_replaying, with_bootstrap=False, with_intrinsic_reward=flags.intrinsic_reward)
			# Train
			self._train(replay=True, batch=old_batch)
		# Update buffer priorities (names chosen not to shadow the builtins id/type)
		for replayed_batch, key, type_id in batch_to_update:
			_, batch_intrinsic_reward = replayed_batch.get_cumulative_reward(self.agents_set)
			with self.experience_buffer_lock:
				self.experience_buffer.update_priority(key, batch_intrinsic_reward, type_id)
		
	def finalize_batch(self, composite_batch, global_step):
		"""Evaluate and train on the newest batch of ``composite_batch``, then maybe store it.

		The newest batch is played through the critic with bootstrapping. Intrinsic
		rewards are computed only when ``flags.intrinsic_reward`` is set and
		``global_step > flags.intrinsic_reward_step``. The batch is then trained on.

		When replay is enabled (``flags.replay_mean > 0``), every batch of
		``composite_batch`` is added to the replay buffer in either of two cases:
		the newest batch has a positive cumulative extrinsic reward, or
		``flags.replay_only_best_batches`` is unset and the newest batch is terminal.
		The composite batch is cleared after being stored.
		"""
		last_batch = composite_batch.get()[-1]
		use_intrinsic = flags.intrinsic_reward and global_step > flags.intrinsic_reward_step
		self._play_critic(last_batch, with_value=False, with_bootstrap=True, with_intrinsic_reward=use_intrinsic)
		self._train(replay=False, batch=last_batch)

		if not (flags.replay_mean > 0):
			return
		extrinsic_reward, _ = last_batch.get_cumulative_reward(self.agents_set)
		# "Best" batches are those leading to a positive extrinsic reward
		is_best = extrinsic_reward > 0
		# NOTE: recomputing the cumulative return over the whole composite batch for best
		# batches (when values are not recomputed on replay) is currently disabled.
		should_store = is_best or (not flags.replay_only_best_batches and last_batch.terminal)
		if not should_store:
			return
		for stored_batch in composite_batch.get():
			self._add_to_replay_buffer(batch=stored_batch, is_best=is_best)
		composite_batch.clear()
示例#26
0
class TcpConnection(object):
    """A gevent-based TCP connection that frames traffic through a codec.

    ``run`` spawns a sender greenlet that drains the outgoing queue, then runs the
    receive loop in the calling greenlet. Every decoded message is passed to
    ``processor(connection, message)``.
    """

    def __init__(self, sock: socket.socket, codec: Codec, processor):
        self._socket = sock
        self._codec = codec
        self._processor = processor
        self._buffer = Buffer()
        self._stop = False
        self._queue = Queue()  # encoded outgoing messages; None is the stop sentinel
        self._last_active_time = time.time()
        self.address = self._socket.getpeername()

    @property
    def last_active_time(self):
        """``time.time()`` value recorded after the last processed message (or at creation)."""
        return self._last_active_time

    def _recv(self):
        """Decode and dispatch incoming messages until the peer closes, recv fails or
        the connection is stopped.

        ``close`` is always scheduled on exit, including when the codec or the
        processor raises, so the socket and the sender greenlet are not leaked.
        """
        try:
            while not self._stop:
                msg = self._codec.decode(self._buffer, self)
                if msg is None:
                    # No complete message buffered yet: read more bytes
                    try:
                        data = self._socket.recv(8 * 1024)
                    except Exception as e:
                        logger.error("Connection:%s, recv message, exception:%s", self.address, e)
                        break
                    if not data:
                        # Empty read: the peer closed the connection
                        break
                    self._buffer.shrink()
                    self._buffer.append(data)
                    continue
                self._processor(self, msg)
                self._last_active_time = time.time()
        finally:
            gevent.spawn(self.close)

    def _send(self):
        """Write queued messages to the socket until the stop sentinel arrives.

        A failed send schedules ``close``. Without that, the connection would stay
        open with no writer left to drain the queue.
        """
        while not self._stop:
            try:
                data = self._queue.get()
                if data is None:
                    break
                self._socket.sendall(data)
            except Exception as e:
                logger.error("Connection:%s, send message, exception:%s", self.address, e)
                gevent.spawn(self.close)
                break

    def run(self):
        """Start the sender greenlet and run the receive loop in the current greenlet."""
        gevent.spawn(self._send)
        self._recv()

    def close(self):
        """Close the connection; calls after the first one are no-ops."""
        if self._socket is None:
            return
        self._stop = True
        self._queue.put(None)  # wake the sender greenlet so it can exit
        self._socket.close()
        self._socket = None
        logger.info("Connection:%s, close", self.address)

    def send_message(self, m):
        """Encode ``m`` with the codec and queue it for the sender greenlet.

        Messages sent after ``close`` are dropped. The sender has stopped by then,
        so queueing them would only grow the queue forever.
        """
        if self._stop:
            return
        array = self._codec.encode(m, self)
        self._queue.put(array)
示例#27
0
class Gem(ContinualModel):
    """GEM (Gradient Episodic Memory) continual model.

    At the end of each task a slice of its data is stored in the buffer with its
    task label. During training, the gradient of the loss on each past task's
    buffered examples is computed. If the current gradient has a negative dot
    product with any of them, it is projected with ``project2cone2`` (margin
    ``args.gamma``) before the optimiser step.
    """
    NAME = 'gem'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il']

    def __init__(self, backbone, loss, args, transform):
        super(Gem, self).__init__(backbone, loss, args, transform)
        self.current_task = 0
        self.buffer = Buffer(self.args.buffer_size, self.device)

        # Allocate temporary synaptic memory: flattened size of every parameter
        self.grad_dims = [pp.data.numel() for pp in self.parameters()]

        self.grads_cs = []  # one flattened gradient per finished task
        self.grads_da = torch.zeros(np.sum(self.grad_dims)).to(self.device)
        # The transform passed to buffer.get_data is only used when args.iba is set
        self.transform = transform if self.args.iba else None

    def end_task(self, dataset):
        """Allocate a gradient slot for the finished task and buffer some of its examples.

        Each task gets ``args.buffer_size // dataset.N_TASKS`` examples, taken from
        the first batch of the non-augmented data loader.
        """
        self.current_task += 1
        self.grads_cs.append(
            torch.zeros(np.sum(self.grad_dims)).to(self.device))

        # add data to the buffer
        samples_per_task = self.args.buffer_size // dataset.N_TASKS

        loader = dataset.not_aug_dataloader(self.args, samples_per_task)
        cur_x, cur_y = next(iter(loader))[:2]
        # Size the task labels from the batch actually returned, so they always
        # match the number of stored examples
        task_labels = torch.full((cur_y.shape[0],), self.current_task - 1,
                                 dtype=torch.long, device=self.device)
        self.buffer.add_data(
            examples=cur_x.to(self.device),
            labels=cur_y.to(self.device),
            task_labels=task_labels)

    def observe(self, inputs, labels, not_aug_inputs):
        """Take one GEM-constrained optimisation step on the current batch.

        :returns: the loss on the current batch as a Python float.
        """
        if not self.buffer.is_empty():
            buf_inputs, buf_labels, buf_task_labels = self.buffer.get_data(
                self.args.buffer_size, transform=self.transform)
            chunk = self.args.batch_size

            for tt in buf_task_labels.unique():
                # compute gradient on the memory buffer
                self.opt.zero_grad()
                task_mask = buf_task_labels == tt
                cur_task_inputs = buf_inputs[task_mask]
                cur_task_labels = buf_labels[task_mask]
                n_task = cur_task_inputs.shape[0]

                # Summed loss per chunk divided by the task's total count accumulates
                # the gradient of the mean loss while bounding the forward-pass size
                for lo in range(0, n_task, chunk):
                    cur_task_outputs = self.forward(cur_task_inputs[lo:lo + chunk])
                    penalty = self.loss(
                        cur_task_outputs,
                        cur_task_labels[lo:lo + chunk],
                        reduction='sum') / n_task
                    penalty.backward()
                store_grad(self.parameters, self.grads_cs[tt], self.grad_dims)

        # now compute the grad on the current data
        self.opt.zero_grad()
        outputs = self.forward(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()

        # check if gradient violates buffer constraints
        if not self.buffer.is_empty():
            # copy gradient
            store_grad(self.parameters, self.grads_da, self.grad_dims)

            memory_grads = torch.stack(self.grads_cs).T
            dot_prod = torch.mm(self.grads_da.unsqueeze(0), memory_grads)
            if (dot_prod < 0).any():
                project2cone2(self.grads_da.unsqueeze(1),
                              memory_grads,
                              margin=self.args.gamma)
                # copy gradients back
                overwrite_grad(self.parameters, self.grads_da, self.grad_dims)

        self.opt.step()

        return loss.item()
示例#28
0
    def fill_buffer(self, mem_buffer: Buffer, dataset, t_idx: int) -> None:
        """
        Adds examples from the current task to the memory buffer
        by means of the herding strategy.

        When ``t_idx > 0`` the buffer is first emptied and refilled with at most
        ``samples_per_class`` previously stored examples per class. Then, for each
        class yielded by the current task's loader, examples are picked greedily.
        Each pick minimises the distance between the class mean feature and the
        mean feature of the examples picked so far.

        The network runs in eval mode without autograd. Its previous training mode
        is restored on exit, also when an exception is raised.

        :param mem_buffer: the memory buffer
        :param dataset: the dataset from which take the examples
        :param t_idx: the task index
        """
        mode = self.net.training
        self.net.eval()
        try:
            with torch.no_grad():
                samples_per_class = mem_buffer.buffer_size // len(self.classes_so_far)

                if t_idx > 0:
                    # 1) First, subsample prior classes.
                    # NOTE(review): this reads self.buffer but writes mem_buffer; it only
                    # matters if a caller passes a buffer other than self.buffer.
                    buf_x, buf_y, buf_l = self.buffer.get_all_data()

                    mem_buffer.empty()
                    for label in buf_y.unique():
                        mask = buf_y == label
                        mem_buffer.add_data(examples=buf_x[mask][:samples_per_class],
                                            labels=buf_y[mask][:samples_per_class],
                                            logits=buf_l[mask][:samples_per_class])

                # 2) Then, fill with current tasks
                loader = dataset.not_aug_dataloader(self.args, self.args.batch_size)

                # 2.1 Extract all features (only the network input needs the device)
                a_x, a_y, a_f, a_l = [], [], [], []
                for x, y, not_norm_x in loader:
                    feats = self.net.features(x.to(self.device))
                    a_x.append(not_norm_x.to('cpu'))
                    a_y.append(y.to('cpu'))
                    a_f.append(feats.cpu())
                    a_l.append(torch.sigmoid(self.net.classifier(feats)).cpu())
                a_x, a_y, a_f, a_l = torch.cat(a_x), torch.cat(a_y), torch.cat(a_f), torch.cat(a_l)

                # 2.2 Compute class means and herd examples around them
                for label in a_y.unique():
                    mask = a_y == label
                    cls_x, cls_y, cls_l = a_x[mask], a_y[mask], a_l[mask]
                    feats = a_f[mask]
                    mean_feat = feats.mean(0, keepdim=True)

                    running_sum = torch.zeros_like(mean_feat)
                    for i in range(min(samples_per_class, feats.shape[0])):
                        # Distance to the class mean if each candidate were picked next
                        cost = (mean_feat - (feats + running_sum) / (i + 1)).norm(2, 1)
                        idx_min = cost.argmin().item()

                        mem_buffer.add_data(
                            examples=cls_x[idx_min:idx_min + 1].to(self.device),
                            labels=cls_y[idx_min:idx_min + 1].to(self.device),
                            logits=cls_l[idx_min:idx_min + 1].to(self.device))

                        running_sum += feats[idx_min:idx_min + 1]
                        # Push the chosen feature far away so it is not picked again
                        feats[idx_min] = feats[idx_min] + 1e6

                assert len(mem_buffer.examples) <= mem_buffer.buffer_size
        finally:
            self.net.train(mode)
示例#29
0
 def __init__(self, backbone, loss, args, transform):
     """Initialise the base continual model, then create a ``Buffer`` of ``args.buffer_size``."""
     super().__init__(backbone, loss, args, transform)
     self.buffer = Buffer(self.args.buffer_size, self.device)
示例#30
0
文件: fdr.py 项目: yxue3357/mammoth
class Fdr(ContinualModel):
    """Continual model 'fdr' that rehearses on buffered examples via their stored logits.

    At the end of each task, examples of that task are stored together with the
    network's logits on them. During training, an extra step minimises the mean
    L2 distance between the softmax of the current outputs on buffered examples
    and the softmax of their stored logits.
    """
    NAME = 'fdr'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    def __init__(self, backbone, loss, args, transform):
        super(Fdr, self).__init__(backbone, loss, args, transform)
        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.current_task = 0
        self.i = 0  # number of observe() calls
        self.soft = torch.nn.Softmax(dim=1)
        self.logsoft = torch.nn.LogSoftmax(dim=1)

    def end_task(self, dataset):
        """Rebalance the buffer and store examples (with logits) of the task just finished.

        Every task seen so far keeps at most ``args.buffer_size // current_task``
        examples.
        """
        self.current_task += 1
        examples_per_task = self.args.buffer_size // self.current_task

        if self.current_task > 1:
            # Shrink each previous task's share to make room for the new task
            buf_x, buf_log, buf_tl = self.buffer.get_all_data()
            self.buffer.empty()

            for ttl in buf_tl.unique():
                idx = (buf_tl == ttl)
                ex, log, tasklab = buf_x[idx], buf_log[idx], buf_tl[idx]
                first = min(ex.shape[0], examples_per_task)
                self.buffer.add_data(examples=ex[:first],
                                     logits=log[:first],
                                     task_labels=tasklab[:first])
        counter = 0
        with torch.no_grad():
            for inputs, _, not_aug_inputs in dataset.train_loader:
                remaining = examples_per_task - counter
                # Stop BEFORE the forward pass once the quota is filled. This avoids
                # extra (possibly train-mode) forwards and empty add_data calls.
                if remaining <= 0:
                    break
                inputs = inputs.to(self.device)
                not_aug_inputs = not_aug_inputs.to(self.device)
                outputs = self.net(inputs)
                # Size by the actual batch, which may be smaller than args.batch_size
                n_batch = inputs.shape[0]
                task_labels = torch.ones(n_batch) * (self.current_task - 1)
                self.buffer.add_data(
                    examples=not_aug_inputs[:remaining],
                    logits=outputs.data[:remaining],
                    task_labels=task_labels[:remaining])
                counter += n_batch

    def observe(self, inputs, labels, not_aug_inputs):
        """Take one step on the current batch, plus a rehearsal step if the buffer is not empty.

        :returns: the last loss optimised as a Python float. This is the
            rehearsal loss when a rehearsal step was taken.
        """
        self.i += 1

        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)
        loss.backward()
        self.opt.step()
        if not self.buffer.is_empty():
            self.opt.zero_grad()
            buf_inputs, buf_logits, _ = self.buffer.get_data(
                self.args.minibatch_size, transform=self.transform)
            buf_outputs = self.net(buf_inputs)
            # Mean L2 distance between current and stored output distributions
            loss = torch.norm(
                self.soft(buf_outputs) - self.soft(buf_logits), 2, 1).mean()
            assert not torch.isnan(loss)
            loss.backward()
            self.opt.step()

        return loss.item()