Example #1
    def run(self, train_tasks, valid_tasks, test_tasks, input_shape, device):

        # Create the feature extractor.
        # Note: "dataset" and "fc_neurons" are referenced but not defined in this method;
        # they are assumed to come from the enclosing scope (e.g. module-level configuration).
        if dataset == "omni":
            features = ConvBase(output_size=64, hidden=32, channels=1, max_pool=False)
        else:
            features = ConvBase(output_size=64, channels=3, max_pool=True)
        features = torch.nn.Sequential(features, Lambda(lambda x: x.view(-1, fc_neurons)))
        features.to(device)

        head = torch.nn.Linear(fc_neurons, self.params['ways'])
        head = MAML(head, lr=self.params['inner_lr'])
        head.to(device)

        # Setup optimization
        all_parameters = list(features.parameters()) + list(head.parameters())
        optimizer = torch.optim.Adam(all_parameters, lr=self.params['outer_lr'])
        loss = torch.nn.CrossEntropyLoss(reduction='mean')

        self.log_model(features, device, input_shape=input_shape, name='features')  # Input shape is specific to dataset
        head_input_shape = (self.params['ways'], fc_neurons)
        self.log_model(head, device, input_shape=head_input_shape, name='head')  # Input shape is specific to dataset

        t = trange(self.params['num_iterations'])
        try:
            for iteration in t:
                optimizer.zero_grad()
                meta_train_loss = 0.0
                meta_train_accuracy = 0.0
                meta_valid_loss = 0.0
                meta_valid_accuracy = 0.0
                for task in range(self.params['meta_batch_size']):
                    # Compute meta-training loss
                    learner = head.clone()
                    batch = train_tasks.sample()
                    eval_loss, eval_acc = fast_adapt(batch, learner, loss,
                                                     self.params['adapt_steps'],
                                                     self.params['shots'], self.params['ways'],
                                                     device, features=features)
                    eval_loss.backward()
                    meta_train_loss += eval_loss.item()
                    meta_train_accuracy += eval_acc.item()

                    # Compute meta-validation loss
                    learner = head.clone()
                    batch = valid_tasks.sample()
                    eval_loss, eval_acc = fast_adapt(batch, learner, loss,
                                                     self.params['adapt_steps'],
                                                     self.params['shots'], self.params['ways'],
                                                     device, features=features)
                    meta_valid_loss += eval_loss.item()
                    meta_valid_accuracy += eval_acc.item()

                meta_train_loss = meta_train_loss / self.params['meta_batch_size']
                meta_valid_loss = meta_valid_loss / self.params['meta_batch_size']
                meta_train_accuracy = meta_train_accuracy / self.params['meta_batch_size']
                meta_valid_accuracy = meta_valid_accuracy / self.params['meta_batch_size']

                metrics = {'train_loss': meta_train_loss,
                           'train_acc': meta_train_accuracy,
                           'valid_loss': meta_valid_loss,
                           'valid_acc': meta_valid_accuracy}
                t.set_postfix(metrics)
                self.log_metrics(metrics)

                # Average the accumulated gradients and optimize
                for p in all_parameters:
                    p.grad.data.mul_(1.0 / self.params['meta_batch_size'])
                optimizer.step()

                if iteration % self.params['save_every'] == 0:
                    self.save_model_checkpoint(features, 'features_' + str(iteration + 1))
                    self.save_model_checkpoint(head, 'head_' + str(iteration + 1))

        # Support safe manual interruption of training
        except KeyboardInterrupt:
            print('\nManually stopped training! Start evaluation & saving...\n')
            self.logger['manually_stopped'] = True
            self.params['num_iterations'] = iteration

        self.save_model(features, name='features')
        self.save_model(head, name='head')

        self.logger['elapsed_time'] = str(round(t.format_dict['elapsed'], 2)) + ' sec'
        # Meta-testing on unseen tasks
        self.logger['test_acc'] = evaluate(self.params, test_tasks, head, loss, device, features=features)
        self.log_metrics({'test_acc': self.logger['test_acc']})
        self.save_logs_to_file()
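
Both the loop above and Example #3 rely on a fast_adapt helper that is not reproduced in these excerpts. The following is a minimal sketch of what such a helper could look like, assuming the common learn2learn convention of interleaving support and query samples within each task batch; the accuracy helper, the interleaved split, and the optional features argument (which Example #1 uses to apply a shared backbone before the adapted head) are assumptions rather than the original code.

    import numpy as np
    import torch


    def accuracy(predictions, targets):
        # Fraction of correctly classified query examples
        predictions = predictions.argmax(dim=1)
        return (predictions == targets).sum().float() / targets.size(0)


    def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device, features=None):
        data, labels = batch
        data, labels = data.to(device), labels.to(device)
        if features is not None:
            # Shared backbone is applied once; only the head (learner) is adapted per task
            data = features(data)

        # Split the task batch into adaptation (support) and evaluation (query) sets,
        # assuming support and query samples are interleaved
        adaptation_indices = np.zeros(data.size(0), dtype=bool)
        adaptation_indices[np.arange(shots * ways) * 2] = True
        evaluation_indices = torch.from_numpy(~adaptation_indices)
        adaptation_indices = torch.from_numpy(adaptation_indices)
        adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
        evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

        # Inner-loop adaptation on the support set
        for _ in range(adaptation_steps):
            adaptation_error = loss(learner(adaptation_data), adaptation_labels)
            learner.adapt(adaptation_error)

        # Evaluate the adapted learner on the query set
        predictions = learner(evaluation_data)
        evaluation_error = loss(predictions, evaluation_labels)
        evaluation_accuracy = accuracy(predictions, evaluation_labels)
        return evaluation_error, evaluation_accuracy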
Example #2
    # ------
    model = AnalogBeamformer(n_antenna=num_antenna, n_beam=N)
    maml = MAML(model, lr=fast_lr, first_order=True)
    # Training:
    # ---------
    optimizer = optim.Adam(model.parameters(), lr=meta_lr, betas=(0.9, 0.999), amsgrad=False)
    loss_fn = bf_gain_loss

    for iteration in range(nepoch):
        optimizer.zero_grad()
        meta_train_error = 0.0
        meta_valid_error = 0.0
        for task in range(batch_size):
            dataset.change_cluster()
            # Compute meta-training loss
            learner = maml.clone()
            batch_idc = dataset.sample()
            batch = (h_concat_scaled[batch_idc, :], egc_gain_scaled[batch_idc])
            evaluation_error = fast_adapt_est_h(batch, learner, loss_fn,
                                                update_step, shots, h_est_force_z)
            evaluation_error.backward()
            meta_train_error += evaluation_error.item()
    
            # Compute meta-validation loss
            learner = maml.clone()
            batch_idc = dataset.sample()
            batch = (h_concat_scaled[batch_idc, :], egc_gain_scaled[batch_idc])
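
The excerpt above breaks off after the meta-validation batch is sampled. Judging from the analogous loops in Examples #1 and #3, the remainder of a single outer iteration would typically look like the sketch below; this is an extrapolation from the surrounding examples, not part of the original snippet.

            # Meta-validation error is tracked without a backward pass (sketch)
            evaluation_error = fast_adapt_est_h(batch, learner, loss_fn,
                                                update_step, shots, h_est_force_z)
            meta_valid_error += evaluation_error.item()

        # Average the gradients accumulated over the meta-batch and take the outer step
        for p in maml.parameters():
            p.grad.data.mul_(1.0 / batch_size)
        optimizer.step()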
Example #3
    def run(self, train_tasks, valid_tasks, test_tasks, model, input_shape, device):

        model.to(device)
        maml = MAML(model, lr=self.params['inner_lr'], first_order=False)
        opt = torch.optim.Adam(maml.parameters(), self.params['outer_lr'])
        loss = torch.nn.CrossEntropyLoss(reduction='mean')

        self.log_model(maml, device, input_shape=input_shape)  # Input shape is specific to dataset

        t = trange(self.params['num_iterations'])
        try:

            for iteration in t:
                # Clear the gradients accumulated during the previous outer-loop iteration
                opt.zero_grad()
                # Initialize iteration's metrics
                meta_train_loss = 0.0
                meta_train_accuracy = 0.0
                meta_valid_loss = 0.0
                meta_valid_accuracy = 0.0
                # Inner (Adaptation) loop
                for task in range(self.params['meta_batch_size']):
                    # Compute meta-training loss
                    learner = maml.clone()
                    batch = train_tasks.sample()
                    eval_loss, eval_acc = fast_adapt(batch, learner, loss,
                                                     self.params['adapt_steps'],
                                                     self.params['shots'], self.params['ways'],
                                                     device)

                    # Backpropagate the evaluation (query) loss through the adaptation steps
                    # to accumulate gradients in the meta-parameters
                    eval_loss.backward()
                    meta_train_loss += eval_loss.item()
                    meta_train_accuracy += eval_acc.item()

                    # Compute meta-validation loss
                    learner = maml.clone()
                    batch = valid_tasks.sample()
                    eval_loss, eval_acc = fast_adapt(batch, learner, loss,
                                                     self.params['adapt_steps'],
                                                     self.params['shots'], self.params['ways'],
                                                     device)
                    meta_valid_loss += eval_loss.item()
                    meta_valid_accuracy += eval_acc.item()

                meta_train_loss = meta_train_loss / self.params['meta_batch_size']
                meta_valid_loss = meta_valid_loss / self.params['meta_batch_size']
                meta_train_accuracy = meta_train_accuracy / self.params['meta_batch_size']
                meta_valid_accuracy = meta_valid_accuracy / self.params['meta_batch_size']

                metrics = {'train_loss': meta_train_loss,
                           'train_acc': meta_train_accuracy,
                           'valid_loss': meta_valid_loss,
                           'valid_acc': meta_valid_accuracy}
                t.set_postfix(metrics)
                self.log_metrics(metrics)

                # Average the accumulated gradients and optimize
                for p in maml.parameters():
                    p.grad.data.mul_(1.0 / self.params['meta_batch_size'])
                opt.step()

                if iteration % self.params['save_every'] == 0:
                    self.save_model_checkpoint(model, str(iteration))

        # Support safe manual interruption of training
        except KeyboardInterrupt:
            print('\nManually stopped training! Start evaluation & saving...\n')
            self.logger['manually_stopped'] = True
            self.params['num_iterations'] = iteration

        self.save_model(model)

        self.logger['elapsed_time'] = str(round(t.format_dict['elapsed'], 2)) + ' sec'
        # Meta-testing on unseen tasks
        self.logger['test_acc'] = evaluate(self.params, test_tasks, maml, loss, device)
        self.log_metrics({'test_acc': self.logger['test_acc']})
        self.save_logs_to_file()
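
Examples #1 and #3 both finish by calling an evaluate helper on the held-out meta-test tasks. A minimal sketch of what such a helper might do, assuming it reuses fast_adapt and averages query accuracy over a number of freshly sampled test tasks; the task count, the features keyword, and the dictionary keys are assumptions.

    def evaluate(params, test_tasks, maml, loss, device, features=None):
        # Adapt a fresh clone to each unseen test task and average the query accuracy
        meta_test_accuracy = 0.0
        num_test_tasks = params.get('num_test_tasks', 100)  # assumed key and default
        for _ in range(num_test_tasks):
            learner = maml.clone()
            batch = test_tasks.sample()
            _, eval_acc = fast_adapt(batch, learner, loss,
                                     params['adapt_steps'],
                                     params['shots'], params['ways'],
                                     device, features=features)
            meta_test_accuracy += eval_acc.item()
        return meta_test_accuracy / num_test_tasks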