Example #1
def train(model, crit, optim, dataset, embedding_sz, device, epoch):
    for i, data in enumerate(dataset):
        img, lbl = data
        img = img.to(device)
        z = torch.empty(img.shape[0], embedding_sz, device=device).normal_()

        # Real images
        optim['D'].zero_grad()
        score_real = model['D'](img)
        l_real = crit(score_real, torch.ones(img.shape[0], device=device))

        # Fake images
        G_z = model['G'](z)
        score_fake = model['D'](G_z.detach())
        l_fake = crit(score_fake, torch.zeros(img.shape[0], device=device))

        l_disc = 0.5*(l_fake + l_real)
        l_disc.backward()
        optim['D'].step()

        # Generator Training
        optim['G'].zero_grad()
        score_fake = model['D'](G_z)
        l_gen = crit(score_fake, torch.ones(img.shape[0], device=device))
        l_gen.backward()
        optim['G'].step()
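A minimal driver sketch for the loop above, assuming `model` and `optim` are dicts keyed by `'G'`/`'D'` (mirroring the `optim['D']`/`optim['G']` accesses) and that `crit` is a binary real/fake loss; `Generator`, `Discriminator`, `dataloader`, and `num_epochs` are hypothetical placeholders, not definitions from this snippet.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
embedding_sz = 128

# Hypothetical wiring: networks and optimizers bundled in dicts keyed by 'G'/'D'.
model = {'G': Generator(embedding_sz).to(device),
         'D': Discriminator().to(device)}
optim_dict = {'G': torch.optim.Adam(model['G'].parameters(), lr=2e-4, betas=(0.5, 0.999)),
              'D': torch.optim.Adam(model['D'].parameters(), lr=2e-4, betas=(0.5, 0.999))}
crit = torch.nn.BCEWithLogitsLoss()

for epoch in range(num_epochs):
    train(model, crit, optim_dict, dataloader, embedding_sz, device, epoch)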
Example #2
    def build(self, model_params):
        cfg_as_dict = self.cfg.asdict()

        name = cfg_as_dict.pop("name", self.defaults.name)
        optim = self.optimizers[name]

        return optim(model_params, **cfg_as_dict)
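A hedged sketch of how such a builder might be wired and called; `OptimizerBuilder`, `OptimConfig`, and `Defaults` are hypothetical names for the surrounding (unshown) classes, assumed to expose `cfg`, `optimizers`, and `defaults` exactly as the method uses them.

# Hypothetical usage: the config's "name" selects the optimizer class,
# every remaining config field is forwarded as a constructor kwarg.
builder = OptimizerBuilder(
    cfg=OptimConfig(name="adam", lr=1e-3, weight_decay=1e-4),
    optimizers={"sgd": torch.optim.SGD, "adam": torch.optim.Adam},
    defaults=Defaults(name="sgd"),
)
optimizer = builder.build(model.parameters())
# -> torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)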
Example #3
def loss_on_random_task(initial_model, K, num_steps, optim=torch.optim.SGD):
    """
    trains the model on a random sine task and measures the loss curve.
    
    for each n in num_steps_measured, records the model function after n gradient updates.
    """

    # copy MAML model into a new object to preserve MAML weights during training
    model = nn.Sequential(
        OrderedDict([('l1', nn.Linear(600, 300)), ('relu1', nn.ReLU()),
                     ('l2', nn.Linear(300, 150)), ('relu2', nn.ReLU()),
                     ('l3', nn.Linear(150, 70)), ('relu3', nn.ReLU()),
                     ('l4', nn.Linear(70, 20)), ('relu4', nn.ReLU()),
                     ('l5', nn.Linear(20, 1))]))

    model.load_state_dict(initial_model.state_dict())
    criterion = nn.MSELoss()
    optimiser = optim(model.parameters(), 0.01)

    # train model on a random task

    X, y = tasks.sample_data(K)
    losses = []
    for step in range(1, num_steps + 1):
        loss = criterion(model(X), y) / K
        losses.append(loss.item())

        # compute grad and update inner loop weights
        model.zero_grad()
        loss.backward()
        optimiser.step()

    return losses
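A hedged usage sketch, assuming `maml_model` holds meta-trained weights with the same layer shapes and that matplotlib is imported as `plt`:

# Hypothetical call: fine-tune a copy of the meta-trained model on one random
# task for 10 gradient steps and inspect the resulting loss curve.
losses = loss_on_random_task(maml_model, K=10, num_steps=10, optim=torch.optim.SGD)
plt.plot(losses)
plt.xlabel("gradient steps")
plt.ylabel("loss")
plt.show()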
Example #4
def train_net(net, train_loader, test_loader, only_fc=True,
              optim=optim.Adam,
              loss_fn=nn.CrossEntropyLoss(),
              n_iter=10, device='cpu'):
    net.to(device)

    train_losses = []
    train_acc, val_acc = [], []

    if only_fc:
        # only optimize the fc layer
        optimizer = optim(net.fc.parameters())
    else:
        optimizer = optim(net.parameters())

    for epoch in range(n_iter):
        net.train()
        n = 0
        n_acc = 0
        running_loss = 0.0
        for i, (xx, yy) in tqdm.tqdm(enumerate(train_loader),
                                     total=len(train_loader)):
            xx = xx.to(device)
            yy = yy.to(device)

            out = net(xx)

            loss = loss_fn(out, yy)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n += len(xx)
            _, y_pred = out.max(1)

            n_acc += (y_pred == yy).float().sum().item()
        train_losses.append(running_loss / (i + 1))  # average over the number of batches
        train_acc.append(n_acc / n)

        # test the model
        val_acc.append(eval_net(net, test_loader, device))

        # print the result
        print("epoch : {} \n\t train_losses : {} \n\t train_acc : {} \n\t val_acc: {}".format(
            epoch, train_losses[-1], train_acc[-1], val_acc[-1]), flush=True)
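`eval_net` is called above but not defined in this snippet; a minimal sketch of what it might look like, assuming it returns top-1 accuracy over the given loader:

def eval_net(net, data_loader, device='cpu'):
    # Assumed helper: compute top-1 accuracy without tracking gradients.
    net.eval()
    ys, ypreds = [], []
    with torch.no_grad():
        for xx, yy in data_loader:
            xx, yy = xx.to(device), yy.to(device)
            _, y_pred = net(xx).max(1)
            ys.append(yy)
            ypreds.append(y_pred)
    return (torch.cat(ys) == torch.cat(ypreds)).float().mean().item()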
Example #5
    def train(self, batches=10, steps=100, optim=optim.SGD, **kwargs):
        """Training the model."""
        optimizer = optim([self._pos], **kwargs)
        for b_idx in range(batches):
            optimizer.zero_grad()
            for _ in range(steps):
                self.loss.backward(retain_graph=True)
                optimizer.step()
            log.info(f"batch: {b_idx + 1}/{batches}\tloss: {self.loss}")
Example #6
    def __init__(self,
                 observation_shape,
                 actions_shape,
                 q_lr=.001,
                 p_lr=.001,
                 optim=optim.Adam):

        self.Q_stable = DQN(observation_shape, actions_shape)
        self.Q_unstable = DQN(observation_shape, actions_shape)
        self.Q_stable.load_state_dict(self.Q_unstable.state_dict())
        self.Q_stable.eval()

        self.Policy_stable = PolicyNet(observation_shape, actions_shape)
        self.Policy_unstable = PolicyNet(observation_shape, actions_shape)

        self.Policy_stable.load_state_dict(self.Policy_unstable.state_dict())
        self.Policy_stable.eval()

        self.Q_optim = optim(self.Q_unstable.parameters(), lr=q_lr)
        self.Policy_optim = optim(self.Policy_unstable.parameters(), lr=p_lr)
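The stable copies act as frozen target networks; a training loop would normally re-sync them from the unstable networks from time to time. A sketch, assuming a hard update every `sync_every` optimizer steps on an `agent` instance of this class:

# Hypothetical sync step inside the training loop.
if step % sync_every == 0:
    agent.Q_stable.load_state_dict(agent.Q_unstable.state_dict())
    agent.Policy_stable.load_state_dict(agent.Policy_unstable.state_dict())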
Example #7
def model_functions_at_training(initial_model,
                                X,
                                y,
                                sampled_steps,
                                x_axis,
                                optim=torch.optim.SGD,
                                lr=0.01):
    """
    trains the model on X, y and measures the loss curve.
    
    for each n in sampled_steps, records model(x_axis) after n gradient updates.
    """

    # copy MAML model into a new object to preserve MAML weights during training
    model = nn.Sequential(
        OrderedDict([('l1', nn.Linear(600, 300)), ('relu1', nn.ReLU()),
                     ('l2', nn.Linear(300, 150)), ('relu2', nn.ReLU()),
                     ('l3', nn.Linear(150, 70)), ('relu3', nn.ReLU()),
                     ('l4', nn.Linear(70, 20)), ('relu4', nn.ReLU()),
                     ('l5', nn.Linear(20, 1))]))
    model.load_state_dict(initial_model.state_dict())
    criterion = nn.MSELoss()
    optimiser = optim(model.parameters(), lr)

    # train model on the given data
    num_steps = max(sampled_steps)
    K = X.shape[0]

    losses = []
    outputs = {}
    for step in range(1, num_steps + 1):
        loss = criterion(model(X), y) / K
        losses.append(loss.item())

        # compute grad and update inner loop weights
        model.zero_grad()
        loss.backward()
        optimiser.step()

        # record the model function output
        if step in sampled_steps:
            outputs[step] = model(
                torch.tensor(x_axis,
                             dtype=torch.float).view(-1, 1)).detach().numpy()

    outputs['initial'] = initial_model(
        torch.tensor(x_axis, dtype=torch.float).view(-1, 1)).detach().numpy()

    return outputs, losses
Example #8
def maxloss_perturbs(
    model,
    criterion,
    images,
    labels,
    epsilon=0.3,
    optim=optim.SGD,
    optim_params={"lr": 0.03},
    n_epoches=100,
    verbose=True,
    device=None,
):
    model.eval()

    if device is not None:
        model = model.to(device)
        images = images.to(device)
        labels = labels.to(device)

    original_images = images

    perturbs = (torch.rand(images.shape, device=images.device) / 1000).requires_grad_()
    optimizer = optim([perturbs], **optim_params)

    for e in range(n_epoches + 1):
        optimizer.zero_grad()
        # evaluate on the perturbed images so the gradient reaches perturbs
        output = model(images + perturbs)
        loss = -1 * criterion(output, labels)  # negate to maximise the loss
        loss.backward()
        optimizer.step()

        if e == n_epoches:
            return torch.clamp(perturbs, -epsilon, epsilon)

        # keep the running images inside the epsilon-ball around the originals
        with torch.no_grad():
            images = torch.max(
                torch.min(images + perturbs, original_images + epsilon),
                original_images - epsilon,
            )
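A hedged usage sketch: the returned tensor is a perturbation clamped to `[-epsilon, epsilon]` with the same shape as `images`, so applying it is a plain addition; the final clamp to the `[0, 1]` pixel range and the `model`/`images`/`labels` names are assumptions.

# Hypothetical call site for the attack above.
criterion = nn.CrossEntropyLoss()
perturbs = maxloss_perturbs(model, criterion, images, labels,
                            epsilon=0.3, optim=torch.optim.SGD,
                            optim_params={"lr": 0.03})
adv_images = torch.clamp(images + perturbs, 0.0, 1.0)
clean_acc = (model(images).argmax(dim=1) == labels).float().mean()
adv_acc = (model(adv_images).argmax(dim=1) == labels).float().mean()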
Example #9
def continuous_optim(tensor_list,
                     train_data,
                     loss_fun,
                     epochs=10,
                     val_data=None,
                     other_args=dict()):
    """
    Train a tensor network using gradient descent on input dataset

    Args:
        tensor_list: List of tensors encoding the network being trained
        train_data:  The data used to train the network
        loss_fun:    Scalar-valued loss function of the type 
                        tens_list, data -> scalar_loss
                     (This depends on the task being learned)
        epochs:      Number of epochs to train for. When val_data is given,
                     setting epochs=None implements early stopping
        val_data:    The data used for validation
        other_args:  Dictionary of other arguments for the optimization, 
                     with some options below (feel free to add more)

                        optim: Choice of PyTorch optimizer (default='SGD')
                        lr:    Learning rate for optimizer (default=1e-3)
                        bsize: Minibatch size for training (default=100)
                        reps:  Number of times to repeat 
                               training data per epoch     (default=1)
                        print: Whether to print info       (default=True)
                        dyn_print: use dynamic printing    (default=False)
                        hist:  Whether to return losses
                               from train and val sets     (default=False)
                        momentum: Momentum value for 
                                  continuous optimization  (default=0)
                        cvg_threshold: threshold to test convergence of 
                            optimization (optimization is stopped if 
                            |(prev_loss - cur_loss)/prev_loss| <  cvg_threshold
                            If None, convergence is not checked. If epochs is
                            set as well, then optimization is stopped either when
                            the convergence criterion is met or when epochs is reached
                                                            (default:None)
                        lr_scheduler: a function taking an optimizer as input
                        and returning a learning rate scheduler for this optimizer
                                                            (default:None)
                        save_optimizer_state: if True, other_args should have an empty
                            dict for the key optimizer_state. This dict will contain 
                              {optimizer_state: optimizer state_dict,
                              lr_scheduler_state: scheduler state_dict (if any)}
                            after the function returns.     (default:False)
                        load_optimizer_state: a dictionary that will be used to
                            initialize the optimizer (and scheduler if any) from a
                            previously saved optimizer state.
                                                            (default: None)
                        grad_masking_function: a function taking the list of tensor
                            parameters between the backward pass and the optimizer step
                            (can be used to e.g. zero out parts of the gradient)
                                                            (default: None)
                        stop_condition: a function taking the training and validation loss
                            as input after each epoch and returning True if optimization 
                            should be stopped               (default: None)

    
    Returns:
        better_list: List of tensors with same shape as tensor_list, but
                     having been optimized using the appropriate optimizer.
                     When validation data is given, the model with the 
                     lowest validation loss is output, otherwise the model
                     with lowest training loss
        first_loss:  Initial loss of the model on the validation set, 
                     before any training. If no val set is provided, the
                     first training loss is instead returned
        best_loss:   The value of the validation/training loss for the
                     model output as better_list
        best_epoch:  epoch at which best_model was found
        loss_record: If hist=True in other_args, history of all validation
                     and training losses is returned as a tuple of PyTorch
                     vectors (train_loss, val_loss), with each vector
                     having length equal to number of epochs of training.
                     When no validation loss is provided, the second item
                     (val_loss) is an empty tensor.
    """
    # Check input and initialize local record variables
    early_stop = epochs is None
    has_val = val_data is not None
    optim = other_args.get('optim', 'SGD')
    lr = other_args.get('lr', 1e-3)
    bsize = other_args.get('bsize', 100)
    reps = other_args.get('reps', 1)
    prnt = other_args.get('print', True)
    hist = other_args.get('hist', False)
    dyn_print = other_args.get('dyn_print', False)
    lr_scheduler = other_args.get('lr_scheduler', None)
    cvg_threshold = other_args.get('cvg_threshold', None)
    save_optimizer_state = other_args.get('save_optimizer_state', False)
    load_optimizer_state = other_args.get('load_optimizer_state', None)
    grad_masking_function = other_args.get('grad_masking_function', None)
    momentum = other_args.get('momentum', 0)
    stop_condition = other_args.get('stop_condition', None)

    if save_optimizer_state and 'optimizer_state' not in other_args:
        raise ValueError(
            "an empty dictionary should be passed as the optimizer_state argument to store the"
            " optimizer state.")
    if early_stop and not has_val:
        raise ValueError("Early stopping (epochs=None) requires val_data "
                         "to be input")
    loss_rec, first_loss, best_loss, best_network, best_epoch = [], None, np.inf, tensor_list, 0
    if hist: loss_record = ([], [])  # (train_record, val_record)

    # Function to maybe print, conditioned on `prnt`
    m_print = lambda s: print(s, end='\r'
                              if dyn_print else '\n') if prnt else None

    # Function to record loss information and return whether to stop
    def record_loss(new_loss, new_network, epoch_num):
        # Load record variables from outer scope
        nonlocal loss_rec, first_loss, best_loss, best_network, best_epoch

        # Check for first and best loss
        if best_loss is None or new_loss < best_loss:
            best_loss, best_network, best_epoch = new_loss, new_network, epoch_num
        if first_loss is None:
            first_loss = new_loss

        # Update loss record and check for early stopping. If you want to
        # change early stopping criteria, this is the place to do it.
        window = 2  # Number of epochs kept for checking early stopping
        warmup = 1  # Number of epochs before early stopping is checked
        if len(loss_rec) < window:
            stop, loss_rec = False, loss_rec + [new_loss]
        else:
            # stop = new_loss > sum(loss_rec)/len(loss_rec)
            stop = (new_loss > max(loss_rec)) and (epoch_num >= warmup)
            loss_rec = loss_rec[1:] + [new_loss]

        return stop

    # Another loss logging function, but for recording *all* loss history
    @torch.no_grad()
    def loss_history(new_loss, is_val):
        if not hist: return
        nonlocal loss_record
        loss_record[int(is_val)].append(new_loss)

    # Function to run TN on validation data
    @torch.no_grad()
    def run_val(t_list):
        val_loss = []

        # Note that `batchify` uses different logic for different types
        # of input, so update batchify when you work on tensor completion
        for batch in batchify(val_data):
            val_loss.append(loss_fun(t_list, batch))
        if has_val:
            val_loss = torch.mean(torch.tensor(val_loss))

        return val_loss

    # Copy tensor_list so the original is unchanged
    tensor_list = copy_network(tensor_list)

    # Record the initial validation loss (if we have a validation dataset)
    if has_val: record_loss(run_val(tensor_list), tensor_list, 0)

    # Initialize optimizer, using only the keyword args accepted by the chosen optimizer
    optim = getattr(torch.optim, optim)
    opt_args = signature(optim).parameters.keys()
    kwargs = {'lr': lr, 'momentum': momentum}  # <- Add new options here
    kwargs = {k: v for (k, v) in kwargs.items() if k in opt_args}
    optim = optim(tensor_list, **kwargs)  # Initialize the optimizer
    if lr_scheduler:  # instantiate learning rate scheduler
        scheduler = lr_scheduler(optim)

    if load_optimizer_state:
        optim.load_state_dict(
            other_args["load_optimizer_state"]["optimizer_state"])
        if lr_scheduler:
            scheduler.load_state_dict(
                other_args["load_optimizer_state"]["lr_scheduler_state"])

    # Loop over validation and training for given number of epochs
    ep = 1
    prev_loss = np.inf

    while epochs is None or ep <= epochs:

        # Train network on all the training data
        #from copy import deepcopy
        prev_tensor_list = copy_network(tensor_list)
        #prev_tensor_list = tensor_list
        train_loss, num_train = 0., 0
        for batch in batchify(train_data, batch_size=bsize, reps=reps):
            loss = loss_fun(tensor_list, batch)
            optim.zero_grad()
            loss.backward()
            if grad_masking_function:
                grad_masking_function(tensor_list)
            optim.step()

            with torch.no_grad():
                num_train += 1
                train_loss += loss

        train_loss /= num_train

        if lr_scheduler:
            scheduler.step(train_loss)

        loss_history(train_loss, is_val=False)

        val_loss = run_val(tensor_list) if has_val else None

        val_loss_str = f"Val. loss:  {val_loss.data:.10f}" if has_val else ""
        m_print(
            f"EPOCH {ep} {'('+str(reps)+' reps)' if reps > 1 else ''}\t\t{val_loss_str}\t\t Train loss: {train_loss.data:.10f}\t\t Convergence: {np.abs(train_loss-prev_loss)/prev_loss:.10f}"
        )

        # Get validation loss if we have it, otherwise record training loss
        if has_val:
            # Get and record validation loss, check early stopping condition
            loss_history(val_loss, is_val=True)
            if record_loss(
                    val_loss,
                    copy_network(tensor_list) if has_val else prev_tensor_list,
                    ep) and early_stop:
                print(f"\nEarly stopping condition reached")
                break
        else:
            record_loss(
                train_loss,
                copy_network(tensor_list) if has_val else prev_tensor_list, ep)

        if cvg_threshold and np.abs(train_loss -
                                    prev_loss) / prev_loss < cvg_threshold:
            print(f"\nConvergence criteria reached")
            break
        if stop_condition and stop_condition(train_loss=train_loss,
                                             val_loss=val_loss):
            print(f"\nStopping condition reached")
            break

        prev_loss = train_loss

        ep += 1
    m_print("")

    # Save the optimizer state if needed
    if save_optimizer_state:
        other_args["optimizer_state"]["optimizer_state"] = optim.state_dict()
        if lr_scheduler:
            other_args["optimizer_state"][
                "lr_scheduler_state"] = scheduler.state_dict()

    if hist:
        loss_record = tuple(torch.tensor(fr) for fr in loss_record)
        return best_network, first_loss, best_loss, best_epoch, loss_record
    else:
        return best_network, first_loss, best_loss
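A hedged usage sketch of the API described in the docstring; `tensor_list`, `train_data`, `val_data`, and `loss_fun` are placeholders for the caller's tensor network, datasets, and loss:

# Hypothetical call: train with Adam, early stopping, and full loss history.
other_args = {'optim': 'Adam', 'lr': 1e-3, 'bsize': 50,
              'hist': True, 'print': True}
best_net, first_loss, best_loss, best_epoch, (train_hist, val_hist) = \
    continuous_optim(tensor_list, train_data, loss_fun,
                     epochs=None,        # early stopping ...
                     val_data=val_data,  # ... requires validation data
                     other_args=other_args)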
Example #10
    def __init__(self):
        super(MAML_coat, self).__init__()
        self.model = nn.Sequential(OrderedDict([
            ('l1', nn.Linear(600,300)),
            ('relu1', nn.ReLU()),
            ('l2', nn.Linear(300,150)),
            ('relu2', nn.ReLU()),
            ('l3', nn.Linear(150,70)),
            ('relu3', nn.ReLU()),
            ('l4', nn.Linear(70,20)),
            ('relu4', nn.ReLU()),
            ('l5', nn.Linear(20,1))
        ]))
        
    def forward(self, x):
        return self.model(x)

    def parameterised(self, x, weights):
        # like forward, but uses ``weights`` instead of ``model.parameters()``
        # it'd be nice if this could be generated automatically for any nn.Module...
        x = nn.functional.linear(x, weights[0], weights[1])
        x = nn.functional.relu(x)
        x = nn.functional.linear(x, weights[2], weights[3])
        x = nn.functional.relu(x)
        x = nn.functional.linear(x, weights[4], weights[5])
        x = nn.functional.relu(x)
        x = nn.functional.linear(x, weights[6], weights[7])
        x = nn.functional.relu(x)
        x = nn.functional.linear(x, weights[8], weights[9])
        return x
                 


# %%
class MAMLModel(nn.Module):
    def __init__(self):
        super(MAMLModel, self).__init__()
        self.model = nn.Sequential(OrderedDict([
            ('l1', nn.Linear(1200,600)),
            ('relu1', nn.ReLU()),
            ('l2', nn.Linear(600,200)),
            ('relu2', nn.ReLU()),
            ('l3', nn.Linear(200,100)),
            ('relu3', nn.ReLU()),
            ('l4', nn.Linear(100,50)),
            ('relu4', nn.ReLU()),
            ('l5', nn.Linear(50,25)),
            ('relu5', nn.ReLU()),
            ('l6', nn.Linear(25,1))
        ]))
        
    def forward(self, x):
        return self.model(x)
    
    def parameterised(self, x, weights):
        # like forward, but uses ``weights`` instead of ``model.parameters()``
        # it'd be nice if this could be generated automatically for any nn.Module...
        x = nn.functional.linear(x, weights[0], weights[1])
        x = nn.functional.relu(x)
        x = nn.functional.linear(x, weights[2], weights[3])
        x = nn.functional.relu(x)
        x = nn.functional.linear(x, weights[4], weights[5])
        x = nn.functional.relu(x)
        x = nn.functional.linear(x, weights[6], weights[7])
        x = nn.functional.relu(x)
        x = nn.functional.linear(x, weights[8], weights[9])
        x = nn.functional.relu(x)
        x = nn.functional.linear(x, weights[10], weights[11])
        return x
                        
# %%
class MAML():
    def __init__(self, model, inner_lr, meta_lr, K=10, inner_steps=1, tasks_per_meta_batch=1000):
        
        # important objects
#        self.tasks = tasks
        self.model = model
        self.weights = list(model.parameters()) # the maml weights we will be meta-optimising
        self.criterion = nn.MSELoss()
        self.meta_optimiser = torch.optim.Adam(self.weights, meta_lr)
        
        # hyperparameters
        self.inner_lr = inner_lr
        self.meta_lr = meta_lr
        self.K = K
        self.inner_steps = inner_steps # with the current design of MAML, >1 is unlikely to work well 
        self.tasks_per_meta_batch = tasks_per_meta_batch 
        
        # metrics
        self.plot_every = 10
        self.print_every = 500
        self.meta_losses = []
    
    def inner_loop(self):
        # reset inner model to current maml weights
        temp_weights = [w.clone() for w in self.weights]
        
        # perform training on data sampled from task
        X, y = data_loader_rand(self.K)
        for step in range(self.inner_steps):
            loss = self.criterion(self.model.parameterised(X, temp_weights), y) / self.K
            
            # compute grad and update inner loop weights
            grad = torch.autograd.grad(loss, temp_weights)
            temp_weights = [w - self.inner_lr * g for w, g in zip(temp_weights, grad)]
        
        # sample new data for meta-update and compute loss
        X, y = data_loader_rand(self.K)
        loss = self.criterion(self.model.parameterised(X, temp_weights), y) / self.K
        
        return loss
     
    def main_loop(self, num_iterations):
        epoch_loss = 0
        
        for iteration in range(1, num_iterations+1):
            
            # compute meta loss
            meta_loss = 0
            for i in range(self.tasks_per_meta_batch):
                #task = self.tasks.sample_task()
                #meta_loss += self.inner_loop(task)
                meta_loss += self.inner_loop()
            
            # compute meta gradient of loss with respect to maml weights
            meta_grads = torch.autograd.grad(meta_loss, self.weights)
            
            # assign meta gradient to weights and take optimisation step
            for w, g in zip(self.weights, meta_grads):
                w.grad = g
            self.meta_optimiser.step()
            
            # log metrics
            epoch_loss += meta_loss.item() / self.tasks_per_meta_batch
            
            if iteration % self.print_every == 0:
                print("{}/{}. loss: {}".format(iteration, num_iterations, epoch_loss / self.plot_every))
            
            if iteration % self.plot_every == 0:
                self.meta_losses.append(epoch_loss / self.plot_every)
                epoch_loss = 0
# %%
tasks = Sine_Task_Distribution(0.1, 5, 0, np.pi, -5, 5)
maml = MAML(MAMLModel(), inner_lr=0.01, meta_lr=0.001)
# %%
maml.main_loop(num_iterations=10000)

#%%
plt.plot(maml.meta_losses)


#%%
def loss_on_random_task(initial_model, K, num_steps, optim=torch.optim.SGD):
    """
    trains the model on a random sine task and measures the loss curve.
    
    for each n in num_steps_measured, records the model function after n gradient updates.
    """
    
    # copy MAML model into a new object to preserve MAML weights during training
    model = nn.Sequential(OrderedDict([
        ('l1', nn.Linear(1200,600)),
        ('relu1', nn.ReLU()),
        ('l2', nn.Linear(600,200)),
        ('relu2', nn.ReLU()),
        ('l3', nn.Linear(200,100)),
        ('relu3', nn.ReLU()),
        ('l4', nn.Linear(100,50)),
        ('relu4', nn.ReLU()),
        ('l5', nn.Linear(50,25)),
        ('relu5', nn.ReLU()),
        ('l6', nn.Linear(25,1))
    ]))
    
    model.load_state_dict(initial_model.state_dict())
    criterion = nn.MSELoss()
    optimiser = optim(model.parameters(), 0.01)

    # train model on a random task
    task = tasks.sample_task()
    X, y = task.sample_data(K)
    losses = []
    for step in range(1, num_steps+1):
        loss = criterion(model(X), y) / K
        losses.append(loss.item())

        # compute grad and update inner loop weights
        model.zero_grad()
        loss.backward()
        optimiser.step()
        
    return losses


#%%
def average_losses(initial_model, n_samples, K=10, n_steps=10, optim=torch.optim.SGD):
    """
    Returns the average learning trajectory of the model trained for ``n_steps`` steps over ``n_samples`` tasks.
    """

    x = np.linspace(-5, 5, 2) # dummy input for test_on_new_task
    avg_losses = [0] * n_steps
    for i in range(n_samples):
        losses = loss_on_random_task(initial_model, K, n_steps, optim)
        avg_losses = [l + l_new for l, l_new in zip(avg_losses, losses)]
    avg_losses = [l / n_samples for l in avg_losses]
    
    return avg_losses

#%%
def mixed_pretrained(iterations=500):
    """
    returns a model pretrained on a selection of ``iterations`` random tasks.
    """
    
    # set up model
    model = nn.Sequential(OrderedDict([
        ('l1', nn.Linear(1200,600)),
        ('relu1', nn.ReLU()),
        ('l2', nn.Linear(600,200)),
        ('relu2', nn.ReLU()),
        ('l3', nn.Linear(200,100)),
        ('relu3', nn.ReLU()),
        ('l4', nn.Linear(100,50)),
        ('relu4', nn.ReLU()),
        ('l5', nn.Linear(50,25)),
        ('relu5', nn.ReLU()),
        ('l6', nn.Linear(25,1))
    ]))
    optimiser = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()
    
    # fit the model
    for i in range(iterations):
        
        model.zero_grad()
        x, y = tasks.sample_task().sample_data(10)
        loss = criterion(model(x), y)
        loss.backward()
        optimiser.step()
        
    return model

#%%
pretrained = mixed_pretrained(10000)

plt.plot(average_losses(maml.model.model, n_samples=5000, K=10), label='maml')
plt.plot(average_losses(pretrained,       n_samples=5000, K=10), label='pretrained')
plt.legend()
plt.title("Average learning trajectory for K=10, starting from initial weights")
plt.xlabel("gradient steps taken with SGD")
plt.show()

#%%
plt.plot(average_losses(maml.model.model, n_samples=5000, K=10, optim=torch.optim.Adam), label='maml')
plt.plot(average_losses(pretrained,       n_samples=5000, K=10, optim=torch.optim.Adam), label='pretrained')
plt.legend()
plt.title("Average learning trajectory for K=10, starting from initial weights")
plt.xlabel("gradient steps taken with Adam")
plt.show()

#%%
def model_functions_at_training(initial_model, X, y, sampled_steps, x_axis, optim=torch.optim.SGD, lr=0.01):
    """
    trains the model on X, y and measures the loss curve.
    
    for each n in sampled_steps, records model(x_axis) after n gradient updates.
    """
    
    # copy MAML model into a new object to preserve MAML weights during training
    model = nn.Sequential(OrderedDict([
        ('l1', nn.Linear(1200,600)),
        ('relu1', nn.ReLU()),
        ('l2', nn.Linear(600,200)),
        ('relu2', nn.ReLU()),
        ('l3', nn.Linear(200,100)),
        ('relu3', nn.ReLU()),
        ('l4', nn.Linear(100,50)),
        ('relu4', nn.ReLU()),
        ('l5', nn.Linear(50,25)),
        ('relu5', nn.ReLU()),
        ('l6', nn.Linear(25,1))
    ]))
    model.load_state_dict(initial_model.state_dict())
    criterion = nn.MSELoss()
    optimiser = optim(model.parameters(), lr)

    # train model on the given data
    num_steps = max(sampled_steps)
    K = X.shape[0]
    
    losses = []
    outputs = {}
    for step in range(1, num_steps+1):
        loss = criterion(model(X), y) / K
        losses.append(loss.item())

        # compute grad and update inner loop weights
        model.zero_grad()
        loss.backward()
        optimiser.step()

        # record the model function output
        if step in sampled_steps:
            outputs[step] = model(torch.tensor(x_axis, dtype=torch.float).view(-1, 1)).detach().numpy()
            
    outputs['initial'] = initial_model(torch.tensor(x_axis, dtype=torch.float).view(-1, 1)).detach().numpy()
    
    return outputs, losses

#%%
def plot_sampled_performance(initial_model, model_name, task, X, y, optim=torch.optim.SGD, lr=0.01):
    
    x_axis = np.linspace(-5, 5, 1000)
    sampled_steps=[1,10]
    outputs, losses = model_functions_at_training(initial_model, 
                                                  X, y, 
                                                  sampled_steps=sampled_steps, 
                                                  x_axis=x_axis, 
                                                  optim=optim, lr=lr)

    plt.figure(figsize=(15,5))
    
    # plot the model functions
    plt.subplot(1, 2, 1)
    
    plt.plot(x_axis, task.true_function(x_axis), '-', color=(0, 0, 1, 0.5), label='true function')
    plt.scatter(X, y, label='data')
    plt.plot(x_axis, outputs['initial'], ':', color=(0.7, 0, 0, 1), label='initial weights')
    
    for step in sampled_steps:
        plt.plot(x_axis, outputs[step], 
                 '-.' if step == 1 else '-', color=(0.5, 0, 0, 1),
                 label='model after {} steps'.format(step))
        
    plt.legend(loc='lower right')
    plt.title("Model fit: {}".format(model_name))

    # plot losses
    plt.subplot(1, 2, 2)
    plt.plot(losses)
    plt.title("Loss over time")
    plt.xlabel("gradient steps taken")
    plt.show()

#%%
K = 10
task = tasks.sample_task()
X, y = task.sample_data(K)

plot_sampled_performance(maml.model.model, 'MAML', task, X, y)

#%%
plot_sampled_performance(pretrained, 'pretrained at lr=0.02', task, X, y, lr=0.02)

#%%
K = 5
task = tasks.sample_task()
X, y = task.sample_data(K)

plot_sampled_performance(maml.model.model, 'MAML', task, X, y)
Example #11
def main():
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('mode', choices=['train', 'predict'])
    arg('run_root')
    arg('trained_weight')
    arg('--model', default='densenet169')
    arg('--optimizer', default='adamw')
    arg('--pretrained', type=int, default=1)
    arg('--warmingup', type=int, default=1)
    arg('--batch-size', type=int, default=32)
    arg('--step', type=int, default=1)
    arg('--workers', type=int, default=2 if ON_KAGGLE else 4)
    arg('--lr', type=float, default=1e-4)
    arg('--patience', type=int, default=3)
    arg('--epochs', type=int, default=100)
    arg('--epoch-size', type=int)
    arg('--clean', action='store_true')
    arg('--tta', type=int, default=4)
    arg('--debug', action='store_true')
    args = parser.parse_args()

    run_root = Path(args.run_root)
    trained_weight = Path(args.trained_weight)
    seed_everything(seed=2333)

    model_conv = getattr(models, args.model)(num_classes=N_CLASSES,
                                             pretrained=args.pretrained)
    model_conv.cuda()
    # keep a factory so a fresh optimizer can be built whenever the lr changes
    optim_cls = getattr(optimizers, args.optimizer)
    optim = lambda lr: optim_cls(params=model_conv.parameters(), lr=lr)

    if run_root.exists() and args.clean:
        shutil.rmtree(run_root)

    run_root.mkdir(exist_ok=True, parents=True)
    (run_root / 'params.json').write_text(
        json.dumps(vars(args), indent=4, sort_keys=True))
    batch_size = args.batch_size
    num_workers = args.workers

    if trained_weight.exists():
        model_conv.load_state_dict(torch.load(trained_weight),
                                   strict=False)

    if args.mode == 'train':

        train_loader, valid_loader = read_data(run_root, batch_size,
                                               num_workers)

        if args.warmingup:
            model_conv.freeze_basemodel()
            n_epochs = 1
            p = 0
            valid_loss_min = float("inf")
            optimizer = optim(lr=1e-2)
            for i in range(n_epochs):
                start_time = time.time()
                avg_loss, avg_corrects, avg_auc = train(i)
                avg_val_loss, avg_val_corrects = test()
                elapsed_time = time.time() - start_time
                print('Epoch {}/{} \t loss={:.4f} acc={:.4f} auc={:.4f} \t val_loss={:.4f} val_acc={:.4f} \t time={:.2f}s'.format(\
                    i + 1, n_epochs, avg_loss, avg_corrects, avg_auc, avg_val_loss, avg_val_corrects, elapsed_time))

        optimizer = optim(lr=args.lr)
        model_conv.unfreeze_model()
        n_epochs = args.epochs
        patience = args.patience
        # (re)initialise early-stopping state so it exists even when warm-up is skipped
        p = 0
        valid_loss_min = float("inf")
        for i in range(n_epochs):
            start_time = time.time()
            avg_loss, avg_corrects, avg_auc = train(model_conv, train_loader,
                                                    i)
            avg_val_loss, avg_val_corrects = test(model_conv, valid_loader)
            elapsed_time = time.time() - start_time
            print('Epoch {}/{} \t loss={:.4f} acc={:.4f} auc={:.4f} \t val_loss={:.4f} val_acc={:.4f} \t time={:.2f}s'.format(\
                i + 1, n_epochs, avg_loss, avg_corrects, avg_auc, avg_val_loss, avg_val_corrects, elapsed_time))

            if avg_val_loss <= valid_loss_min:
                print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(\
                valid_loss_min,avg_val_loss))
                torch.save(model_conv.state_dict(), 'model.pt')
                valid_loss_min = avg_val_loss
                p = 0

            if avg_val_loss > valid_loss_min:
                p += 1
                print(f'{p} epochs of increasing val loss')
                if p >= 1:
                    print('Decrease learning rate')
                    optimizer = optim(lr=args.lr / 10)
                if p > patience:
                    print('Stopping training')
                    stop = True
                    break
    else:
        num_tta = args.tta
        test_loader = read_test(run_root, batch_size, num_workers)
        model_conv.eval().cuda()

        for tta in range(num_tta):
            preds = []
            for batch_i, (data, target) in enumerate(test_loader):
                data, target = data.cuda(), target.cuda()
                output = model_conv(data).detach()

                pr = output[:, 0].cpu().numpy()
                for i in pr:
                    preds.append(i)

            test_preds = pd.DataFrame({
                'imgs': test_set.image_files_list,
                'preds': preds
            })
            test_preds['imgs'] = test_preds['imgs'].apply(
                lambda x: x.split('.')[0])
            sub = pd.read_csv(f'{run_root}/data/sample_submission.csv')
            sub = pd.merge(sub, test_preds, left_on='id', right_on='imgs')
            sub = sub[['id', 'preds']]
            sub.columns = ['id', 'label']
            sub.head()
            sub.to_csv('single_model_' + str(tta) + '.csv', index=False)

        del model_conv
        gc.collect()
        torch.cuda.empty_cache()
Example #12
def train(network,
          train_dataloader,
          test_dataloader,
          trainDataset,
          testDataset,
          vocab,
          epochs,
          learning_range=0.001,
          load_checkpoint=False,
          checkpoint_file=CHECKPOINT_FILE,
          criterion=nn.CrossEntropyLoss(),
          optim=torch.optim.Adam):

    print("TRAIN STARTED!")
    log_file = open(TRAIN_LOG_FILE, 'w')

    train_loss_epochs = []
    test_loss_epochs = []
    optimizer = optim(network.parameters(), lr=learning_range)
    best_test_score = 10**6

    if load_checkpoint:
        open_checkpoint(network,
                        optimizer,
                        is_best=False,
                        filename=checkpoint_file)
    try:
        for epoch in range(epochs):

            train_loss_sum = 0.0
            train_loss_count = 0

            sample_id = 0
            print("Epoch: {} Training".format(epoch + 1))
            for sample in tqdm(train_dataloader):

                sample_id += 1
                torch.cuda.empty_cache()

                features = sample['image']
                ann_ids = sample['anns']
                batch_size = features.shape[0]
                lengths = sample['ann_len']
                max_len = lengths.max()

                captions = load_anns(trainDataset,
                                     ann_ids,
                                     max_len,
                                     prepare=lambda w: vocab(w))
                captions = captions.long()

                lengths, perm_index = lengths.sort(0, descending=True)
                lengths = lengths.numpy()
                captions = captions[perm_index]
                features = features[perm_index]

                captions = Variable(captions).cuda()
                features = Variable(features).cuda()
                targets = pack_padded_sequence(captions,
                                               lengths,
                                               batch_first=True)[0]

                outputs = network.forward(features, captions, lengths)
                targets = targets.cuda()
                #               outputs = outputs.cpu()
                #               targets = targets.cpu()

                loss_batch = criterion(outputs, targets)
                train_loss_sum += loss_batch.item()
                train_loss_count += 1.0

                loss_batch.backward()
                optimizer.step()
                optimizer.zero_grad()
                del features, captions, loss_batch, sample, outputs, targets, lengths
                if sample_id % 200 == 0:
                    gc.collect()

            gc.collect()
            train_loss_epochs.append(train_loss_sum / train_loss_count)

            test_loss_sum = 0.0
            test_loss_count = 0
            sample_id = 0
            torch.cuda.empty_cache()

            print("Epoch: {} Testing".format(epoch + 1))
            for sample in tqdm(test_dataloader):
                sample_id += 1
                torch.cuda.empty_cache()

                features = sample['image']
                ann_ids = sample['anns']
                batch_size = features.shape[0]
                lengths = sample['ann_len']
                max_len = lengths.max()

                captions = load_anns(testDataset,
                                     ann_ids,
                                     max_len,
                                     prepare=lambda w: vocab(w))
                captions = captions.long()

                lengths, perm_index = lengths.sort(0, descending=True)
                lengths = lengths.numpy()
                captions = captions[perm_index]
                features = features[perm_index]

                captions = Variable(captions).cuda()
                features = Variable(features).cuda()
                targets = pack_padded_sequence(captions,
                                               lengths,
                                               batch_first=True)[0]

                outputs = network.forward(features, captions, lengths)
                targets = targets.cuda()
                #               outputs = outputs.cpu()
                #              targets = targets.cpu()

                loss_batch = criterion(outputs, targets)
                test_loss_sum += loss_batch.item()
                test_loss_count += 1.0

                del features, captions, loss_batch, sample, outputs, targets, lengths
                if sample_id % 200 == 0:
                    gc.collect()

            test_loss_epochs.append(test_loss_sum / (test_loss_count))

            test_network(network, testDataset, vocab, epoch)

            is_best = test_loss_epochs[-1] < best_test_score
            best_test_score = min(test_loss_epochs[-1], best_test_score)
            save_checkpoint(
                {
                    'net': network,
                    'epoch': epoch + 1,
                    'state_dict': network.state_dict(),
                    'best_test_score': best_test_score,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                filename='checkpoints/checkpoint_{}.pth.tar'.format(epoch + 1))
            log_file.write(
                '\rEpoch {0}... (Train/Test) Loss: {1:.3f}/{2:.3f}\n'.format(
                    epoch, train_loss_epochs[-1], test_loss_epochs[-1]))

            sys.stdout.write(
                '\rEpoch {0}... (Train/Test) Loss: {1:.3f}/{2:.3f}\n'.format(
                    epoch, train_loss_epochs[-1], test_loss_epochs[-1]))
            gc.collect()

    except KeyboardInterrupt:
        pass
    # plt.figure(figsize=(12, 5))
    # plt.plot(train_loss_epochs[1:], label='Train')
    # plt.plot(test_loss_epochs[1:], label='Test')
    # plt.xlabel('Epochs', fontsize=16)
    # plt.ylabel('Loss', fontsize=16)
    # plt.legend(loc=0, fontsize=16)
    # plt.grid('on')
    # plt.savefig(TRAIN_PLT_FILE)
    #
    gc.collect()

    print("TRAIN ENDED!")
Example #13
def main(_run, _config, _log):
    settings = Settings()
    common_params, data_params, net_params, train_params, eval_params = settings['COMMON'], settings['DATA'], settings[
        'NETWORK'], settings['TRAINING'], settings['EVAL']

    if _run.observers:
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.cuda.set_device(device=_config['gpu_id'])
    torch.set_num_threads(1)

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'BCV':
        make_data = meta_data
    else:
        print(f"data name : {data_name}")
        raise ValueError('Wrong config for dataset!')

    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    trainloader = DataLoader(
        dataset=tr_dataset,
        batch_size=_config['batch_size'],
        shuffle=True,
        num_workers=_config['n_work'],
        pin_memory=False,  # True would load batches into pinned memory while the GPU trains
        drop_last=True
    )
    _log.info('###### Create model ######')
    model = fs.FewShotSegmentorDoubleSDnet(net_params).cuda()
    model.train()

    _log.info('###### Set optimizer ######')
    optim = torch.optim.Adam
    optim_args = {"lr": train_params['learning_rate'],
                  "weight_decay": train_params['optim_weight_decay'],}
                  # "momentum": train_params['momentum']}
    optim_c = optim(list(model.conditioner.parameters()), **optim_args)
    optim_s = optim(list(model.segmentor.parameters()), **optim_args)
    scheduler_s = lr_scheduler.StepLR(optim_s, step_size=100, gamma=0.1)
    scheduler_c = lr_scheduler.StepLR(optim_c, step_size=100, gamma=0.1)
    criterion = losses.DiceLoss()

    if _config['record']:  ## tensorboard visualization
        _log.info('###### define tensorboard writer #####')
        _log.info(f'##### board/train_{_config["board"]}_{date()}')
        writer = SummaryWriter(f'board/train_{_config["board"]}_{date()}')

    iter_print = _config["iter_print"]
    iter_n_train = len(trainloader)
    _log.info('###### Training ######')
    for i_epoch in range(_config['n_steps']):
        epoch_loss = 0
        for i_iter, sample_batched in enumerate(trainloader):
            # Prepare input
            s_x = sample_batched['s_x'].cuda()  # [B, Support, slice_num=1, 1, 256, 256]
            X = s_x.squeeze(2)  # [B, Support, 1, 256, 256]
            s_y = sample_batched['s_y'].cuda()  # [B, Support, slice_num, 1, 256, 256]
            Y = s_y.squeeze(2)  # [B, Support, 1, 256, 256]
            Y = Y.squeeze(2)  # [B, Support, 256, 256]
            q_x = sample_batched['q_x'].cuda()  # [B, slice_num, 1, 256, 256]
            query_input = q_x.squeeze(1)  # [B, 1, 256, 256]
            q_y = sample_batched['q_y'].cuda()  # [B, slice_num, 1, 256, 256]
            y2 = q_y.squeeze(1)  # [B, 1, 256, 256]
            y2 = y2.squeeze(1)  # [B, 256, 256]
            y2 = y2.type(torch.LongTensor).cuda()

            entire_weights = []
            for shot_id in range(_config["n_shot"]):
                input1 = X[:, shot_id, ...] # use 1 shot at first
                y1 = Y[:, shot_id, ...] # use 1 shot at first
                condition_input = torch.cat((input1, y1.unsqueeze(1)), dim=1)
                weights = model.conditioner(condition_input) # 2, 10, [B, channel=1, w, h]
                entire_weights.append(weights)

            # pdb.set_trace()
            avg_weights=[[],[None, None, None, None]]
            for i in range(9):
                weight_cat = torch.cat([weights[0][i] for weights in entire_weights],dim=1)
                avg_weight = torch.mean(weight_cat,dim=1,keepdim=True)
                avg_weights[0].append(avg_weight)

            avg_weights[0].append(None)

            output = model.segmentor(query_input, avg_weights)
            loss = criterion(F.softmax(output, dim=1), y2)
            optim_s.zero_grad()
            optim_c.zero_grad()
            loss.backward()
            optim_s.step()
            optim_c.step()

            epoch_loss += loss.item()  # detach so the graph is not kept across iterations
            if iter_print:
                print(f"train, iter:{i_iter}/{iter_n_train}, iter_loss:{loss}", end='\r')

        scheduler_c.step()
        scheduler_s.step()
        print(f'step {i_epoch+1}: loss: {epoch_loss}                               ')

        if _config['record']:
            batch_i = 0
            frames = []
            query_pred = output.argmax(dim=1)
            query_pred = query_pred.unsqueeze(1)
            frames += overlay_color(q_x[batch_i,0], query_pred[batch_i].float(), q_y[batch_i,0])
            # frames += overlay_color(s_xi[batch_i], blank, s_yi[batch_i], scale=_config['scale'])
            visual = make_grid(frames, normalize=True, nrow=2)
            writer.add_image("train/visual", visual, i_epoch)

        save_fname = f'{_run.observers[0].dir}/snapshots/last.pth'
        torch.save(model.state_dict(),save_fname)
Example #14
def main():
    # Get arguments
    args = parse_args()

    # Set random seed
    torch.manual_seed(args.seed)

    # Cuda
    use_cuda = False
    if torch.cuda.is_available():
        if not args.cuda:
            print("WARNING: You have a CUDA device, so you \
            should probably run with --cuda")
        else:
            use_cuda = True
            torch.cuda.manual_seed(args.seed)

    # Load data + text fields
    print('=' * 89)
    train_iter, val_iter, test_iter, SRC, TRG = utils.load_dataset(
        batch_size=args.batch_size,
        use_pretrained_emb=args.pretrained_emb,
        save_dir=SAVE_DIR
    )
    print('=' * 89)

    # Initialize model
    enc = models.EncoderRNN(
        input_size=len(SRC.vocab),
        emb_size=(SRC.vocab.vectors.size(1)
                  if args.pretrained_emb == 'fastText'
                  else args.emb_size),
        embeddings=(SRC.vocab.vectors
                    if args.pretrained_emb == 'fastText'
                    else None),
        max_norm=args.emb_maxnorm,
        padding_idx=SRC.vocab.stoi['<pad>'],
        hidden_size=args.hidden_size,
        num_layers=args.num_layers,
        dropout=args.dropout,
        bidirectional=args.bidirectional
    )
    decoder = models.AttnDecoderRNN if args.attention else models.DecoderRNN
    dec = decoder(
        enc_num_directions=enc.num_directions,
        enc_hidden_size=args.hidden_size,
        use_context=args.use_context,
        input_size=len(TRG.vocab),
        emb_size=(TRG.vocab.vectors.size(1)
                  if args.pretrained_emb
                  else args.emb_size),
        embeddings=(TRG.vocab.vectors
                    if args.pretrained_emb
                    else None),
        max_norm=args.emb_maxnorm,
        padding_idx=TRG.vocab.stoi['<pad>'],
        hidden_size=args.hidden_size,
        num_layers=args.num_layers,
        dropout=args.dropout,
        bidirectional=False # args.bidirectional
    )
    model = models.Seq2Seq(enc, dec, use_cuda=use_cuda)
    if use_cuda:
        model.cuda()
    print(model)

    # Initialize loss
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=TRG.vocab.stoi["<pad>"])

    # Create optimizer
    if args.optimizer == 'Adam':
        optim = torch.optim.Adam
    elif args.optimizer == 'Adadelta':
        optim = torch.optim.Adadelta
    elif args.optimizer == 'Adagrad':
        optim = torch.optim.Adagrad
    else:
        optim = torch.optim.SGD
    optimizer = optim(model.parameters(), lr=args.lr)

    # Create scheduler
    lambda_lr = lambda epoch: 0.5 if epoch > 8 else 1
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda_lr)

    # Train
    best_val_loss = None
    fname = './{}/{}.pt'.format(SAVE_DIR, args.save)

    print('=' * 89)
    try:
        for epoch in range(1, args.epochs+1):
            epoch_start_time = time.time()

            attns = train(epoch, model, train_iter, criterion, optimizer,
                  use_cuda, args, SRC, TRG)
            val_loss = evaluate(model, val_iter, criterion, use_cuda)

            # Log results
            print('-' * 89)
            print('| end of epoch {:3d} | time: {:5.2f}s '
                  '| valid loss {:5.2f} | valid ppl {:8.2f}'.format(
                      epoch, (time.time() - epoch_start_time),
                      val_loss, math.exp(val_loss)))
            print('-' * 89)

            # Save the model if the validation loss is the best we've seen so far
            if not best_val_loss or val_loss < best_val_loss:
                if not os.path.isdir(SAVE_DIR):
                    os.makedirs(SAVE_DIR)
                torch.save(model, fname)
                best_val_loss = val_loss

            # Anneal learning rate
            scheduler.step()
    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')

    # Load the best saved model
    with open(fname, 'rb') as f:
        model = torch.load(f)

    # Run on test data
    test_loss = evaluate(model, test_iter, criterion, use_cuda)

    # Log results
    print('=' * 89)
    print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
        test_loss, math.exp(test_loss)))
    print('=' * 89)