Example #1
# Assumed imports for this snippet; MLP, get_input and to_var are
# project-specific helpers (a sketch follows the example).
import os

import numpy as np
import torch
import torch.nn as nn


def main(args):
    # Create the model directory if it does not exist
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # One-off preprocessing: uncomment this block to (re)build the cached
    # dataset, then re-run with it commented out again.
    # dataset, targets = load_dataset()
    # np.save("__cache_dataset.npy", dataset)
    # np.save("__cache_targets.npy", targets)
    # return

    dataset = np.load("__cache_dataset.npy")
    targets = np.load("__cache_targets.npy")

    # Build the models
    mlp = MLP(args.input_size, args.output_size)

    mlp.load_state_dict(
        torch.load(
            '_backup_model_statedict/mlp_100_4000_PReLU_ae_dd_final.pkl'))

    if torch.cuda.is_available():
        mlp.cuda()

    # Loss and Optimizer
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adagrad(mlp.parameters())

    # Train the Models
    total_loss = []
    print(len(dataset))
    print(len(targets))
    sm = 100  # start saving models after 100 epochs
    for epoch in range(args.num_epochs):
        print("epoch" + str(epoch))
        avg_loss = 0
        for i in range(0, len(dataset), args.batch_size):
            # Forward, Backward and Optimize
            mlp.zero_grad()
            bi, bt = get_input(i, dataset, targets, args.batch_size)
            bi = to_var(bi)
            bt = to_var(bt)
            bo = mlp(bi)
            loss = criterion(bo, bt)
            avg_loss = avg_loss + loss.item()
            loss.backward()
            optimizer.step()
        print("--average loss:")
        print(avg_loss / (len(dataset) / args.batch_size))
        total_loss.append(avg_loss / (len(dataset) / args.batch_size))
        # Save the models
        if epoch == sm:
            model_path = 'mlp_100_4000_PReLU_ae_dd' + str(sm) + '.pkl'
            torch.save(mlp.state_dict(),
                       os.path.join(args.model_path, model_path))
            sm = sm + 50  # from epoch 100 onwards, save every 50 epochs
    torch.save(total_loss, 'total_loss.dat')
    model_path = 'mlp_100_4000_PReLU_ae_dd_final.pkl'
    torch.save(mlp.state_dict(), os.path.join(args.model_path, model_path))
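
Example #1 depends on helpers (MLP, to_var, get_input) defined elsewhere in its project. Below is a minimal sketch of plausible definitions, assuming a single-hidden-layer PReLU regressor; the hidden width of 4000 is only a guess read off the checkpoint name mlp_100_4000_PReLU_ae_dd, and every name and shape here is hypothetical.

# Hypothetical sketch only -- the real project defines these helpers itself.
import torch
import torch.nn as nn


class MLP(nn.Module):
    def __init__(self, input_size, output_size, hidden_size=4000):
        super(MLP, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.PReLU(),
            nn.Linear(hidden_size, output_size))

    def forward(self, x):
        return self.net(x)


def to_var(x):
    # Convert a numpy array to a float tensor, on GPU when available
    t = torch.as_tensor(x, dtype=torch.float32)
    return t.cuda() if torch.cuda.is_available() else t


def get_input(i, dataset, targets, batch_size):
    # Slice one mini-batch out of the cached numpy arrays
    return dataset[i:i + batch_size], targets[i:i + batch_size]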
Example #2
# Assumed imports: PyTorch plus this repo's MetaOptimizer, MetaModel and MLP.
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable


def main():
    # Create a meta optimizer that wraps a model into a meta model
    # to keep track of the meta updates.
    #    meta_model = FullyConnectedNN()
    meta_model = MLP()
    print(meta_model)

    if args.cuda:
        meta_model.cuda()

    meta_optimizer = MetaOptimizer(MetaModel(meta_model), args.num_layers,
                                   args.hidden_size)
    if args.cuda:
        meta_optimizer.cuda()

    optimizer = optim.Adam(meta_optimizer.parameters(), lr=1e-3)

    for epoch in range(args.max_epoch):
        decrease_in_loss = 0.0
        final_loss = 0.0
        train_iter = iter(train_loader)
        for i in range(args.updates_per_epoch):

            # Sample a new model
            #model = FullyConnectedNN()
            model = MLP()
            if args.cuda:
                model.cuda()

            x, y = next(train_iter)
            if args.cuda:
                x, y = x.cuda(), y.cuda()
            x, y = Variable(x), Variable(y)

            # Compute initial loss of the model
            f_x = model(x)
            initial_loss = F.nll_loss(f_x, y)

            for k in range(args.optimizer_steps // args.truncated_bptt_step):
                # Keep states for truncated BPTT
                meta_optimizer.reset_lstm(keep_states=k > 0,
                                          model=model,
                                          use_cuda=args.cuda)

                loss_sum = 0
                prev_loss = torch.zeros(1)
                if args.cuda:
                    prev_loss = prev_loss.cuda()
                for j in range(args.truncated_bptt_step):
                    x, y = next(train_iter)
                    if args.cuda:
                        x, y = x.cuda(), y.cuda()
                    x, y = Variable(x), Variable(y)

                    # First we need to compute the gradients of the model
                    f_x = model(x)
                    loss = F.nll_loss(f_x, y)
                    model.zero_grad()
                    loss.backward()

                    # Perform a meta update using the gradients from the model
                    # and return the current meta model saved in the optimizer
                    meta_model = meta_optimizer.meta_update(model, loss.data)

                    # Compute a loss for one step of the meta optimizer
                    f_x = meta_model(x)
                    loss = F.nll_loss(f_x, y)

                    loss_sum += (loss - Variable(prev_loss))

                    prev_loss = loss.data

                # Update the parameters of the meta optimizer
                meta_optimizer.zero_grad()
                loss_sum.backward()
                for param in meta_optimizer.parameters():
                    param.grad.data.clamp_(-1, 1)
                optimizer.step()

            # Compute relative decrease in the loss function w.r.t initial
            # value
            # .item() replaces the deprecated loss.data[0] indexing
            decrease_in_loss += loss.item() / initial_loss.item()
            final_loss += loss.item()

        print("Epoch: {}, final loss {}, average final/initial loss ratio: {}".
              format(epoch, final_loss / args.updates_per_epoch,
                     decrease_in_loss / args.updates_per_epoch))
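
Example #2 leaves MLP and train_loader undefined. Because the objective is F.nll_loss, the model has to output log-probabilities (i.e. end in log_softmax). Here is a sketch under that assumption, using MNIST purely as a stand-in dataset; the layer sizes and loader settings are hypothetical.

# Hypothetical sketch -- the actual repo ships its own model and loader.
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)                  # flatten the images
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)   # nll_loss expects log-probs


train_loader = DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=32, shuffle=True)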
Example #3
# Assumed imports: scanpy, PyTorch, torchmeta and tqdm, plus this project's
# makedata, preprocessing and helpfuntions modules.
import os

import scanpy as sc
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchmeta.utils.gradient_based import gradient_update_parameters
from tqdm import tqdm


def train(args):

    perturb_mock, sgRNA_list_mock = makedata.json_to_perturb_data(
        path="/home/member/xywang/WORKSPACE/MaryGUO/one-shot/MOCK_MON_crispr_combine/crispr_analysis")

    total = sc.read_h5ad("/home/member/xywang/WORKSPACE/MaryGUO/one-shot/mock_one_perturbed.h5ad")
    trainset, testset = preprocessing.make_total_data(total, sgRNA_list_mock)

    TrainSet = perturbdataloader(trainset, ways=args.num_ways,
                                 support_shots=args.num_shots, query_shots=15)
    TrainLoader = DataLoader(TrainSet, batch_size=args.batch_size_train,
                             shuffle=False, num_workers=args.num_workers)

    model = MLP(out_features=args.num_ways)

    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(TrainLoader, total=args.num_batches) as pbar:
        for batch_idx, (inputs_support, inputs_query, target_support, target_query) in enumerate(pbar):
            model.zero_grad()

            inputs_support = inputs_support.to(device=args.device)
            target_support = target_support.to(device=args.device)

            inputs_query = inputs_query.to(device=args.device)
            target_query = target_query.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for task_idx, (train_input, train_target, test_input, test_target) in enumerate(
                    zip(inputs_support, target_support, inputs_query, target_query)):

                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                params = gradient_update_parameters(model,
                                                    inner_loss,
                                                    step_size=args.step_size,
                                                    first_order=args.first_order)

                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size_train)
            accuracy.div_(args.batch_size_train)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= args.num_batches or accuracy.item() > 0.95:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(
            args.output_folder,
            'maml_omniglot_{0}shot_{1}way.th'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)

    # start test
    test_support, test_query, test_target_support, test_target_query = \
        helpfuntions.sample_once(testset, support_shot=args.num_shots,
                                 shuffle=False, plus=len(trainset))
    test_query = torch.from_numpy(test_query).to(device=args.device)
    test_target_query = torch.from_numpy(test_target_query).to(device=args.device)

    TestSet = perturbdataloader_test(test_support, test_target_support)
    TestLoader = DataLoader(TestSet, args.batch_size_test)

    meta_optimizer.zero_grad()
    inner_losses = []
    accuracy_test = []

    for epoch in range(args.num_epoch):
        model.to(device=args.device)
        model.train()

        for inputs_support, target_support in TestLoader:

            inputs_support = inputs_support.to(device=args.device)
            target_support = target_support.to(device=args.device)

            train_logit = model(inputs_support)
            loss = F.cross_entropy(train_logit, target_support)
            inner_losses.append(loss.item())  # store the scalar, not the autograd graph
            loss.backward()
            meta_optimizer.step()
            meta_optimizer.zero_grad()

            # Evaluate on the held-out query set without building a graph
            with torch.no_grad():
                test_logit = model(test_query)
                accuracy = get_accuracy(test_logit, test_target_query)
                accuracy_test.append(accuracy.item())

        if (epoch + 1) % 3 == 0:
            print('Epoch [{}/{}], Loss: {:.4f}, accuracy: {:.4f}'.format(
                epoch + 1, args.num_epoch, loss.item(), accuracy.item()))
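
Example #3 calls model(test_input, params=params), which only works if MLP is a torchmeta MetaModule (a plain nn.Module does not accept a params argument); gradient_update_parameters is torchmeta's built-in inner-loop update. Below is a sketch of a compatible model plus a get_accuracy in the style of torchmeta's MAML example; the input width of 2000 genes and the hidden size are assumptions.

# Hypothetical sketch -- a MetaModule MLP compatible with params=... calls.
import torch
import torch.nn as nn
from torchmeta.modules import MetaLinear, MetaModule, MetaSequential


class MLP(MetaModule):
    def __init__(self, out_features, in_features=2000, hidden_size=256):
        super(MLP, self).__init__()
        self.net = MetaSequential(
            MetaLinear(in_features, hidden_size),
            nn.ReLU(),
            MetaLinear(hidden_size, out_features))

    def forward(self, inputs, params=None):
        # Route the (possibly adapted) parameter dict into the submodule
        return self.net(inputs, params=self.get_subdict(params, 'net'))


def get_accuracy(logits, targets):
    # Fraction of correct predictions (mirrors torchmeta's MAML example)
    _, predictions = torch.max(logits, dim=-1)
    return torch.mean(predictions.eq(targets).float())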