Example #1

import copy

import numpy as np
import scanpy as sc
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# project-local helpers used below; the exact import paths are assumptions
import makedata
import preprocessing
import utils
from dataloader import perturbdataloader
from evaluation import evaluate
from logger import Logger
from models import MLP


def run(args, num_workers=1, log_interval=100, verbose=True, save_path=None):
    utils.set_seed(args.seed)

    # ---------------------------------------------------------
    # -------------------- load data ---------------------------
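    # NOTE: the data paths below are hard-coded and machine-specific; adapt
    # them to your environment before running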

    perturb_mock, sgRNA_list_mock = makedata.json_to_perturb_data(
        path="/home/member/xywang/WORKSPACE/MaryGUO/one-shot/MOCK_MON_crispr_combine/crispr_analysis")

    total = sc.read_h5ad("/home/member/xywang/WORKSPACE/MaryGUO/one-shot/mock_one_perturbed.h5ad")
    trainset, testset = preprocessing.make_total_data(total, sgRNA_list_mock)

    # ---------------------------------------------------------
    # -------------------- training ---------------------------

    # initialise model
    model = MLP(num_classes=args.n_way,
                input_size=args.dimension,
                num_context_params=args.num_context_params,
                num_film_hidden_layers=args.num_film_hidden_layers,
                context_in=args.context_in,
                initialisation=args.nn_initialisation,
                hidden_size1=args.hidden_size1,
                hidden_size2=args.hidden_size2,
                device=args.device)
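    # The MLP carries two kinds of parameters: shared weights that are
    # meta-learned across tasks, and low-dimensional context parameters that
    # condition the network (optionally through FiLM layers) and are adapted
    # per task in the inner loop (a CAVIA-style split)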

    model.train()

    # set up meta-optimiser for model parameters
    meta_optimiser = torch.optim.Adam(model.parameters(), lr=args.lr_meta)
    scheduler = torch.optim.lr_scheduler.StepLR(meta_optimiser, step_size=5000, gamma=args.lr_meta_decay)

    # initialise logger
    logger = Logger(log_interval, args, verbose=verbose)

    # initialise the starting point for the meta gradient (it's faster to copy
    # this than to create a new object); one slot per parameter, matching the
    # model.parameters() order used when the gradients are accumulated below
    meta_grad_init = [0 for _ in model.parameters()]

    iter_counter = 0
    while iter_counter < args.n_iter:
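        # NOTE: dataset_train below provides tasks_per_metaupdate * n_iter
        # episodes, enough batches for all meta-updates, so this outer loop
        # typically completes in a single pass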

        # the dataset size ('batchsz') here is the total number of episodes to generate
        dataset_train = perturbdataloader(trainset,
                                          args.tasks_per_metaupdate * args.n_iter,
                                          n_ways=args.n_way,
                                          k_shots=args.k_shot,
                                          k_query=args.k_query)

        # each batch yields tasks_per_metaupdate episodes
        dataloader_train = DataLoader(dataset_train, args.tasks_per_metaupdate, shuffle=False,
                                      num_workers=num_workers,
                                      pin_memory=False)

        # initialise the validation dataset and dataloader
        dataset_valid = perturbdataloader(testset,
                                          500 * args.tasks_per_metaupdate,
                                          n_ways=args.n_way,
                                          k_shots=args.k_shot,
                                          k_query=args.k_query,
                                          plus=len(trainset))
        dataloader_valid = DataLoader(dataset_valid, args.tasks_per_metaupdate, shuffle=False,
                                      num_workers=num_workers,
                                      pin_memory=False)

        logger.print_header()

        for step, (support_x, support_y, query_x, query_y) in enumerate(dataloader_train):

            support_x = support_x.to(args.device)
            support_y = support_y.to(args.device)
            query_x = query_x.to(args.device)
            query_y = query_y.to(args.device)

            # skip the batch if it doesn't contain enough tasks (can happen in the last batch)
            if support_x.shape[0] != args.tasks_per_metaupdate:
                continue

            # initialise meta-gradient
            meta_grad = copy.deepcopy(meta_grad_init)

            logger.prepare_inner_loop(iter_counter)

            for inner_batch_idx in range(args.tasks_per_metaupdate):

                # reset context parameters
                model.reset_context_params()

                # -------------- inner update --------------

                logger.log_pre_update(iter_counter, support_x[inner_batch_idx], support_y[inner_batch_idx],
                                      query_x[inner_batch_idx], query_y[inner_batch_idx], model)

                for _ in range(args.num_grad_steps_inner):
                    # forward train data through net
                    pred_train = model(support_x[inner_batch_idx])

                    # compute loss
                    task_loss_train = F.cross_entropy(pred_train, support_y[inner_batch_idx])

                    # compute gradient for context parameters
                    task_grad_train = torch.autograd.grad(task_loss_train,
                                                          model.context_params,
                                                          create_graph=True)[0]

                    # set context parameters to their updated values
                    model.context_params = model.context_params - args.lr_inner * task_grad_train
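                    # NOTE: the update above is functional (context_params is
                    # reassigned rather than modified in place) and, together
                    # with create_graph=True, keeps the inner step differentiable
                    # so the meta-gradient can flow back through the adaptation
                    # (second-order gradients)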


                # -------------- get meta gradient --------------

                # forward the query set through the adapted net
                pred_test = model(query_x[inner_batch_idx])

                # compute loss on the query set
                task_loss_test = F.cross_entropy(pred_test, query_y[inner_batch_idx])

                # compute gradient for shared parameters
                task_grad_test = torch.autograd.grad(task_loss_test, model.parameters())

                # add to meta-gradient
                for g in range(len(task_grad_test)):
                    meta_grad[g] += task_grad_test[g].detach()
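                # (these summed task gradients are averaged over
                # tasks_per_metaupdate in the meta update at the end of the batch)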

                # ------------------------------------------------
                logger.log_post_update(iter_counter, support_x[inner_batch_idx], support_y[inner_batch_idx],
                                       query_x[inner_batch_idx], query_y[inner_batch_idx], model)

            # reset context parameters
            model.reset_context_params()

            # summarise the inner-loop statistics for this meta-batch
            logger.summarise_inner_loop(mode='train')

            if iter_counter % log_interval == 0:
                # evaluate the current model (*before* the meta update so results are comparable)
                evaluate(iter_counter, args, model, logger, dataloader_valid, save_path)
                if save_path is not None:
                    np.save(save_path, [logger.training_stats, logger.validation_stats])
                    # save a CPU copy of the model so it can be loaded without a GPU
                    save_model = model
                    if str(args.device).startswith('cuda'):
                        save_model = copy.deepcopy(model).to(torch.device('cpu'))
                    torch.save(save_model, save_path)

            logger.print(iter_counter, task_grad_train, meta_grad)
            iter_counter += 1

            # -------------- meta update --------------

            meta_optimiser.zero_grad()
            # set the gradients of the shared parameters manually: average the
            # summed task gradients and clip them for stability
            for c, param in enumerate(model.parameters()):
                param.grad = meta_grad[c] / float(args.tasks_per_metaupdate)
                param.grad.clamp_(-10, 10)

            # the meta-optimiser only operates on the shared parameters, not the context parameters
            meta_optimiser.step()
            # step the LR schedule after the optimiser update (the ordering recent PyTorch expects)
            scheduler.step()

            # stop once the requested number of meta-updates has been applied
            if iter_counter >= args.n_iter:
                break

    model.reset_context_params()
    return logger, model
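
For reference, a minimal usage sketch follows. The hyperparameter names mirror the args.* fields read by run() above; the default values and the argparse wiring are assumptions for illustration, not the project's actual configuration.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--n_way', type=int, default=5)
parser.add_argument('--k_shot', type=int, default=1)
parser.add_argument('--k_query', type=int, default=15)
parser.add_argument('--dimension', type=int, default=2000)           # input feature size (assumed)
parser.add_argument('--num_context_params', type=int, default=100)
parser.add_argument('--num_film_hidden_layers', type=int, default=0)
parser.add_argument('--context_in', nargs='+', default=None)         # per-layer flags; format assumed
parser.add_argument('--nn_initialisation', type=str, default=None)   # value assumed
parser.add_argument('--hidden_size1', type=int, default=512)
parser.add_argument('--hidden_size2', type=int, default=256)
parser.add_argument('--lr_meta', type=float, default=1e-3)
parser.add_argument('--lr_meta_decay', type=float, default=0.9)
parser.add_argument('--lr_inner', type=float, default=1.0)
parser.add_argument('--num_grad_steps_inner', type=int, default=2)
parser.add_argument('--tasks_per_metaupdate', type=int, default=16)
parser.add_argument('--n_iter', type=int, default=10000)
parser.add_argument('--device', type=str,
                    default='cuda:0' if torch.cuda.is_available() else 'cpu')
args = parser.parse_args()

logger, model = run(args, num_workers=2, log_interval=100, save_path='./cavia_run')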