import copy

import numpy as np
import scanpy as sc
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Project-local modules. The import paths below are assumptions inferred from
# the identifiers used in run(); adjust them to the actual package layout.
import utils
import makedata
import preprocessing
from logger import Logger                  # assumed location of Logger
from models import MLP                     # assumed location of the CAVIA-style MLP
from dataloader import perturbdataloader   # assumed location of the episode sampler
from evaluation import evaluate            # assumed location of evaluate()


def run(args, num_workers=1, log_interval=100, verbose=True, save_path=None):
    utils.set_seed(args.seed)

    # ---------------------------------------------------------
    # -------------------- load data ---------------------------

    perturb_mock, sgRNA_list_mock = makedata.json_to_perturb_data(
        path="/home/member/xywang/WORKSPACE/MaryGUO/one-shot/MOCK_MON_crispr_combine/crispr_analysis")
    total = sc.read_h5ad("/home/member/xywang/WORKSPACE/MaryGUO/one-shot/mock_one_perturbed.h5ad")
    trainset, testset = preprocessing.make_total_data(total, sgRNA_list_mock)

    # ---------------------------------------------------------
    # -------------------- training ----------------------------

    # initialise model
    model = MLP(num_classes=args.n_way,
                input_size=args.dimension,
                num_context_params=args.num_context_params,
                num_film_hidden_layers=args.num_film_hidden_layers,
                context_in=args.context_in,
                initialisation=args.nn_initialisation,
                hidden_size1=args.hidden_size1,
                hidden_size2=args.hidden_size2,
                device=args.device)
    model.train()

    # set up meta-optimiser for the shared model parameters
    meta_optimiser = torch.optim.Adam(model.parameters(), args.lr_meta)
    scheduler = torch.optim.lr_scheduler.StepLR(meta_optimiser, 5000, args.lr_meta_decay)

    # initialise logger
    logger = Logger(log_interval, args, verbose=verbose)

    # initialise the starting point for the meta-gradient accumulator
    # (it's faster to copy this than to create a new object each batch);
    # one slot per entry of model.parameters(), in the same order as the
    # gradients returned by autograd below
    meta_grad_init = [0 for _ in model.parameters()]

    iter_counter = 0
    while iter_counter < args.n_iter:

        # the second argument is the total number of episodes to sample
        dataset_train = perturbdataloader(trainset, args.tasks_per_metaupdate * args.n_iter,
                                          n_ways=args.n_way, k_shots=args.k_shot, k_query=args.k_query)
        # fetch tasks_per_metaupdate episodes per batch
        dataloader_train = DataLoader(dataset_train, args.tasks_per_metaupdate,
                                      shuffle=False, num_workers=num_workers, pin_memory=False)

        # initialise validation dataloader
        dataset_valid = perturbdataloader(testset, 500 * args.tasks_per_metaupdate,
                                          n_ways=args.n_way, k_shots=args.k_shot, k_query=args.k_query,
                                          plus=len(trainset))
        dataloader_valid = DataLoader(dataset_valid, args.tasks_per_metaupdate,
                                      shuffle=False, num_workers=num_workers, pin_memory=False)

        logger.print_header()

        for step, (support_x, support_y, query_x, query_y) in enumerate(dataloader_train):

            support_x = support_x.to(args.device)
            support_y = support_y.to(args.device)
            query_x = query_x.to(args.device)
            query_y = query_y.to(args.device)

            # skip batch if we don't have enough tasks (might happen in the last batch)
            if support_x.shape[0] != args.tasks_per_metaupdate:
                continue

            # initialise meta-gradient
            meta_grad = copy.deepcopy(meta_grad_init)

            logger.prepare_inner_loop(iter_counter)

            for inner_batch_idx in range(args.tasks_per_metaupdate):

                # reset context parameters
                model.reset_context_params()

                # -------------- inner update --------------

                logger.log_pre_update(iter_counter,
                                      support_x[inner_batch_idx], support_y[inner_batch_idx],
                                      query_x[inner_batch_idx], query_y[inner_batch_idx],
                                      model)

                for _ in range(args.num_grad_steps_inner):
                    # forward support data through the net
                    pred_train = model(support_x[inner_batch_idx])

                    # compute loss
                    task_loss_train = F.cross_entropy(pred_train, support_y[inner_batch_idx])

                    # compute gradient w.r.t. the context parameters only;
                    # create_graph=True keeps the graph so the meta-gradient
                    # can backpropagate through this update
                    task_grad_train = torch.autograd.grad(task_loss_train,
                                                          model.context_params,
                                                          create_graph=True)[0]

                    # set context parameters to their updated values
                    model.context_params = model.context_params - args.lr_inner * task_grad_train

                # -------------- get meta gradient --------------

                # forward query data through the updated net
                pred_test = model(query_x[inner_batch_idx])

                # compute loss on query data
                task_loss_test = F.cross_entropy(pred_test, query_y[inner_batch_idx])

                # compute gradient w.r.t. the shared parameters
                task_grad_test = torch.autograd.grad(task_loss_test, model.parameters())

                # accumulate into the meta-gradient
                for g in range(len(task_grad_test)):
                    meta_grad[g] += task_grad_test[g].detach()

                # ------------------------------------------------

                logger.log_post_update(iter_counter,
                                       support_x[inner_batch_idx], support_y[inner_batch_idx],
                                       query_x[inner_batch_idx], query_y[inner_batch_idx],
                                       model)

            # reset context parameters
            model.reset_context_params()

            # summarise inner loop and get validation performance
            logger.summarise_inner_loop(mode='train')

            if iter_counter % log_interval == 0:
                # evaluate the current model (*before* updating, so results are comparable)
                evaluate(iter_counter, args, model, logger, dataloader_valid, save_path)

                if save_path is not None:
                    np.save(save_path, [logger.training_stats, logger.validation_stats])

                    # save model via CPU
                    save_model = model
                    if args.device == 'cuda:0':
                        save_model = copy.deepcopy(model).to(torch.device('cpu'))
                    torch.save(save_model, save_path)

                logger.print(iter_counter, task_grad_train, meta_grad)

            iter_counter += 1
            if iter_counter > args.n_iter:
                break

            # -------------- meta update --------------

            meta_optimiser.zero_grad()

            # set the gradients of the shared parameters manually
            for c, param in enumerate(model.parameters()):
                param.grad = meta_grad[c] / float(args.tasks_per_metaupdate)
                param.grad.clamp_(-10, 10)

            # the meta-optimiser only operates on the shared parameters,
            # not the context parameters
            meta_optimiser.step()
            # step the LR scheduler after the optimiser (required ordering in PyTorch >= 1.1)
            scheduler.step()

            model.reset_context_params()

    return logger, model
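

# ----------------------------------------------------------------------
# Usage sketch (not part of the original script): a minimal, hypothetical
# example of calling run(). The field names on `args` are exactly the
# attributes accessed above; every concrete value is an illustrative
# placeholder, not a recommended setting.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(
        seed=42,
        device='cuda:0' if torch.cuda.is_available() else 'cpu',
        # episode construction
        n_way=5,                         # classes per episode
        k_shot=1,                        # support examples per class
        k_query=15,                      # query examples per class
        # model architecture
        dimension=2000,                  # input feature size (placeholder)
        hidden_size1=128,
        hidden_size2=64,
        num_context_params=100,
        num_film_hidden_layers=1,
        context_in=[True, True, False],  # which layers receive context (shape assumed)
        nn_initialisation='kaiming',     # placeholder initialisation name
        # optimisation
        lr_meta=1e-3,
        lr_meta_decay=0.9,
        lr_inner=0.1,
        num_grad_steps_inner=2,
        tasks_per_metaupdate=16,
        n_iter=10000,
    )

    logger, model = run(args, num_workers=1, log_interval=100,
                        verbose=True, save_path='cavia_perturb_model.pt')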