Example #1
# Assumes `import numpy as np` and the project's MLP class are in scope.
def test_construction_model(self):
    mlp = MLP(hidden_layer_sizes=(2, ))
    # 3 samples, 2 features each
    x_train = np.zeros([3, 2])
    y_train = [1, 0, 1]
    mlp.train(x_train, y_train)
    # The node three `output_list` hops from the input should not be a
    # softmax layer.
    self.assertTrue(mlp.input.output_list[0].output_list[0]
                    .output_list[0].name.find('softmax') < 0)
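For readability, the long attribute chain could be factored out. The helper below is a hypothetical sketch, not part of the original test; it only assumes each node exposes the `output_list` and `name` attributes used above:

# Hypothetical helper: walk the single-child `output_list` chain and
# collect node names, e.g. collect_node_names(mlp.input, depth=3)[-1]
# in place of the chained lookups in the assertion.
def collect_node_names(node, depth):
    names = []
    for _ in range(depth):
        node = node.output_list[0]
        names.append(node.name)
    return names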
Example #2
# Assumed imports for this snippet; makedata, preprocessing, helpfuntions,
# MLP, perturbdataloader, perturbdataloader_test and get_accuracy are
# project-local modules/helpers from the surrounding repository.
import os
import torch
import torch.nn.functional as F
import scanpy as sc
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchmeta.utils.gradient_based import gradient_update_parameters

def train(args):

    # Load the mock CRISPR perturbation data and build train/test splits.
    perturb_mock, sgRNA_list_mock = makedata.json_to_perturb_data(
        path="/home/member/xywang/WORKSPACE/MaryGUO/one-shot/MOCK_MON_crispr_combine/crispr_analysis")

    total = sc.read_h5ad("/home/member/xywang/WORKSPACE/MaryGUO/one-shot/mock_one_perturbed.h5ad")
    trainset, testset = preprocessing.make_total_data(total, sgRNA_list_mock)

    # Episodic loader: each element is an N-way task with `num_shots`
    # support examples and 15 query examples per class.
    TrainSet = perturbdataloader(trainset, ways=args.num_ways,
                                 support_shots=args.num_shots, query_shots=15)
    TrainLoader = DataLoader(TrainSet, batch_size=args.batch_size_train,
                             shuffle=False, num_workers=args.num_workers)

    model = MLP(out_features=args.num_ways)

    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
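    # MAML outline: for each task in the meta-batch, adapt the model on the
    # support set (inner loop), then score the adapted parameters on the
    # query set; the averaged query loss drives the meta-optimizer.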
    with tqdm(TrainLoader, total=args.num_batches) as pbar:
        for batch_idx, (inputs_support, inputs_query, target_support, target_query) in enumerate(pbar):
            model.zero_grad()

            inputs_support = inputs_support.to(device=args.device)
            target_support = target_support.to(device=args.device)

            inputs_query = inputs_query.to(device=args.device)
            target_query = target_query.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            tasks = zip(inputs_support, target_support, inputs_query, target_query)
            for task_idx, (train_input, train_target, test_input, test_target) in enumerate(tasks):

                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
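                # torchmeta-style functional update: one inner gradient
                # step yields task-adapted parameters without mutating
                # the model's own weights.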
                params = gradient_update_parameters(model,
                                                    inner_loss,
                                                    step_size=args.step_size,
                                                    first_order=args.first_order)

                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

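            # Average the meta-loss and accuracy over tasks in the batch.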
            outer_loss.div_(args.batch_size_train)
            accuracy.div_(args.batch_size_train)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            # Early stop: after num_batches meta-updates, or once query
            # accuracy exceeds 0.95.
            if batch_idx >= args.num_batches or accuracy.item() > 0.95:
                break

    # Save model
    if args.output_folder is not None:
        # Filename apparently inherited from the torchmeta MAML Omniglot
        # example this loop is modeled on.
        filename = os.path.join(args.output_folder,
                                'maml_omniglot_{0}shot_{1}way.th'.format(
                                    args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            torch.save(model.state_dict(), f)

    # Test phase: fine-tune the meta-trained model on the held-out support
    # examples, then track accuracy on the held-out query examples.
    test_support, test_query, test_target_support, test_target_query = \
        helpfuntions.sample_once(testset, support_shot=args.num_shots,
                                 shuffle=False, plus=len(trainset))
    test_query = torch.from_numpy(test_query).to(device=args.device)
    test_target_query = torch.from_numpy(test_target_query).to(device=args.device)

    TestSet = perturbdataloader_test(test_support, test_target_support)
    TestLoader = DataLoader(TestSet, args.batch_size_test)

    meta_optimizer.zero_grad()
    inner_losses = []
    accuracy_test = []

    for epoch in range(args.num_epoch):
        model.to(device=args.device)
        model.train()

        for inputs_support, target_support in TestLoader:

            inputs_support = inputs_support.to(device=args.device)
            target_support = target_support.to(device=args.device)

            train_logit = model(inputs_support)
            loss = F.cross_entropy(train_logit, target_support)
            # Store the scalar, not the tensor, so the computation graph
            # is not kept alive across steps.
            inner_losses.append(loss.item())
            loss.backward()
            meta_optimizer.step()
            meta_optimizer.zero_grad()

            # Evaluate on the query set; the forward pass also belongs
            # inside no_grad since no gradients are needed here.
            with torch.no_grad():
                test_logit = model(test_query)
                accuracy = get_accuracy(test_logit, test_target_query)
                accuracy_test.append(accuracy.item())

        if (epoch + 1) % 3 == 0:
            print('Epoch [{}/{}], Loss: {:.4f}, Accuracy: {:.4f}'.format(
                epoch + 1, args.num_epoch, inner_losses[-1], accuracy_test[-1]))
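A minimal sketch of how train() might be invoked. Every attribute below is read from `args` somewhere in the function, but the concrete values are illustrative assumptions; the original script presumably builds the namespace with argparse:

from types import SimpleNamespace
import torch

# Illustrative hyperparameters only.
args = SimpleNamespace(
    num_ways=5,                # classes per task
    num_shots=1,               # support examples per class
    batch_size_train=16,       # tasks per meta-batch
    batch_size_test=4,
    num_workers=2,
    num_batches=100,           # meta-training iterations
    num_epoch=9,               # fine-tuning epochs at test time
    step_size=0.4,             # inner-loop learning rate
    first_order=False,         # full second-order MAML
    output_folder=None,        # set a path to save the trained model
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
)
train(args)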