Example #1
import math
import time

import numpy as np
import requests
import torch.optim as optim
from pymol import cmd

# Project helpers (set_experiment_id, contruct_dataloader_from_disk,
# ExampleModel, evaluate_model, write_out, calculate_dihedral_angels,
# get_structure_from_angles, write_to_pdb, write_model_to_disk,
# write_result_summary) and the globals `use_gpu` and `args` are assumed
# to be defined elsewhere in the surrounding project.

def train_model(data_set_identifier, train_file, val_file, learning_rate,
                minibatch_size):
    set_experiment_id(data_set_identifier, learning_rate, minibatch_size)

    train_loader = contruct_dataloader_from_disk(train_file, minibatch_size)
    validation_loader = contruct_dataloader_from_disk(val_file, minibatch_size)
    validation_dataset_size = len(validation_loader.dataset)

    model = ExampleModel(21, minibatch_size,
                         use_gpu=use_gpu)  # embed size = 21

    # TODO: is soft_to_angle.parameters() included here?
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    sample_num = []
    train_loss_values = []
    validation_loss_values = []
    rmsd_avg_values = []
    drmsd_avg_values = []

    best_model_loss = float("inf")  # was a magic 1.1; inf guarantees the first evaluation sets a best model
    best_model_minibatch_time = None
    best_model_path = None
    stopping_condition_met = False
    minibatches_processed = 0

    while not stopping_condition_met:
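        # Run until the early-stopping rule below fires: training stops once
        # twice as many minibatches have been processed as at the point where
        # the best validation loss was seen.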
        optimizer.zero_grad()
        model.zero_grad()
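        # accumulates per-minibatch training losses until the next evaluation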
        loss_tracker = np.zeros(0)
        for minibatch_id, training_minibatch in enumerate(train_loader, 0):
            minibatches_processed += 1
            primary_sequence, tertiary_positions, mask = training_minibatch
            start_compute_loss = time.time()
            loss = model.compute_loss(primary_sequence, tertiary_positions)
            write_out("Train loss:", float(loss))
            start_compute_grad = time.time()
            loss.backward()
            loss_tracker = np.append(loss_tracker, float(loss))
            end = time.time()
            write_out("Loss time:", start_compute_grad - start_compute_loss,
                      "Grad time:", end - start_compute_grad)
            optimizer.step()
            optimizer.zero_grad()
            model.zero_grad()

            # every eval_interval minibatches, evaluate performance on the validation set
            if minibatches_processed % args.eval_interval == 0:

                train_loss = loss_tracker.mean()
                loss_tracker = np.zeros(0)
                validation_loss, data_total, rmsd_avg, drmsd_avg = evaluate_model(
                    validation_loader, model)
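                # data_total[0]: (primary sequence, positions, predicted angles,
                # predicted structure) for the first validation sample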
                prim = data_total[0][0]
                pos = data_total[0][1]
                (aa_list, phi_list, psi_list,
                 omega_list) = calculate_dihedral_angels(prim, pos)
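                # write the actual and predicted structures to PDB files and
                # load both into PyMOL for visual comparison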
                write_to_pdb(
                    get_structure_from_angles(aa_list, phi_list[1:],
                                              psi_list[:-1], omega_list[:-1]),
                    "test")
                cmd.load("output/protein_test.pdb")
                write_to_pdb(data_total[0][3], "test_pred")
                cmd.load("output/protein_test_pred.pdb")
                cmd.forward()
                cmd.orient()
                if validation_loss < best_model_loss:
                    best_model_loss = validation_loss
                    best_model_minibatch_time = minibatches_processed
                    best_model_path = write_model_to_disk(model)

                write_out("Validation loss:", validation_loss, "Train loss:",
                          train_loss)
                write_out("Best model so far (label loss): ", validation_loss,
                          "at time", best_model_minibatch_time)
                write_out("Best model stored at " + best_model_path)
                write_out("Minibatches processed:", minibatches_processed)
                sample_num.append(minibatches_processed)
                train_loss_values.append(train_loss)
                validation_loss_values.append(validation_loss)
                rmsd_avg_values.append(rmsd_avg)
                drmsd_avg_values.append(drmsd_avg)
                if args.live_plot:
                    data = {}
                    data["validation_dataset_size"] = validation_dataset_size
                    data["sample_num"] = sample_num
                    data["train_loss_values"] = train_loss_values
                    data["validation_loss_values"] = validation_loss_values
                    data["phi_actual"] = [
                        math.degrees(float(v)) for v in phi_list[1:]
                    ]
                    data["psi_actual"] = [
                        math.degrees(float(v)) for v in psi_list[:-1]
                    ]
                    predicted_angles = data_total[0][2].detach().transpose(0, 1)
                    data["phi_predicted"] = [
                        math.degrees(float(v)) for v in predicted_angles[0][1:]
                    ]
                    data["psi_predicted"] = [
                        math.degrees(float(v)) for v in predicted_angles[1][:-1]
                    ]
                    data["drmsd_avg"] = drmsd_avg_values
                    data["rmsd_avg"] = rmsd_avg_values
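                    # send the collected metrics to a local plotting server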
                    res = requests.post('http://localhost:5000/graph',
                                        json=data)
                    if res.ok:
                        print(res.json())

                if (minibatches_processed > args.minimum_updates
                        and minibatches_processed > best_model_minibatch_time * 2):
                    stopping_condition_met = True
                    break
    write_result_summary(best_model_loss)
    return best_model_path
Example #2
import time

import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from drawnow import drawnow

# Project helpers (set_protien_experiments_id, contruct_data_loader_from_disk,
# ExampleModel, test_eval_model, write_out, write_to_pdb_strcture,
# save_model_on_disk_torch_version, draw_plot, logs) and the globals `use_gpu`,
# `live_plot`, `eval_interval` and `minimum_updates` are assumed to be defined
# elsewhere in the surrounding project.

def train_model(data_set_identifier, train_file, val_file, learning_rate,
                minibatch_size):
    set_protien_experiments_id(data_set_identifier, learning_rate,
                               minibatch_size)

    train_loader = contruct_data_loader_from_disk(train_file, minibatch_size)
    validation_loader = contruct_data_loader_from_disk(val_file,
                                                       minibatch_size)
    validation_dataset_size = len(validation_loader.dataset)

    model = ExampleModel(9, "ONEHOT", minibatch_size, use_gpu=use_gpu)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # plot settings; only create the figure when live plotting is enabled
    fig = None
    if live_plot:
        plt.ion()
        fig = plt.figure()
    sample_num = []
    train_loss_values = []
    validation_loss_values = []

    best_model_loss = float("inf")  # was a magic 1.1; inf guarantees the first evaluation sets a best model
    best_model_minibatch_time = None
    best_model_path = None
    stopping_condition_met = False
    minibatches_processed = 0

    while not stopping_condition_met:
        optimizer.zero_grad()
        model.zero_grad()
        loss_tracker = np.zeros(0)
        for minibatch_id, training_minibatch in enumerate(train_loader, 0):
            minibatches_processed += 1
            primary_sequence, tertiary_positions, mask = training_minibatch
            start_compute_loss = time.time()
            loss = model.neg_log_likelihood(primary_sequence,
                                            tertiary_positions)
            write_out("Train loss:", float(loss))
            start_compute_grad = time.time()
            loss.backward()
            loss_tracker = np.append(loss_tracker, float(loss))
            end = time.time()
            write_out("Loss time:", start_compute_grad - start_compute_loss,
                      "Grad time:", end - start_compute_grad)
            optimizer.step()
            optimizer.zero_grad()
            model.zero_grad()

            # every eval_interval minibatches, evaluate performance on the validation set
            if minibatches_processed % eval_interval == 0:

                train_loss = loss_tracker.mean()
                loss_tracker = np.zeros(0)
                validation_loss, data_total = test_eval_model(
                    validation_loader, model)
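                # data_total[0]: (primary sequence, true positions, predicted
                # positions) for the first validation sample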
                prim = data_total[0][0]
                pos = data_total[0][1].transpose(0, 1).contiguous().view(-1, 3)
                pos_predicted = data_total[0][2].transpose(
                    0, 1).contiguous().view(-1, 3)
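                # write the true and predicted structures to PDB files for inspection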
                write_to_pdb_strcture(pos, prim, "test")
                print(pos)
                # cmd.load("output/protein_test.pdb")
                write_to_pdb_strcture(pos_predicted.detach(), prim,
                                      "test_pred")
                if validation_loss < best_model_loss:
                    best_model_loss = validation_loss
                    best_model_minibatch_time = minibatches_processed
                    best_model_path = save_model_on_disk_torch_version(model)

                write_out("Validation loss:", validation_loss, "Train loss:",
                          train_loss)
                write_out("Best model so far (label loss): ", validation_loss,
                          "at time", best_model_minibatch_time)
                write_out("Best model stored at " + best_model_path)
                write_out("Minibatches processed:", minibatches_processed)
                sample_num.append(minibatches_processed)
                train_loss_values.append(train_loss)
                validation_loss_values.append(validation_loss)
                if live_plot:
                    # drawnow expects a callable, so wrap draw_plot in a lambda
                    # rather than invoking it immediately
                    drawnow(lambda: draw_plot(fig, plt, validation_dataset_size,
                                              sample_num, train_loss_values,
                                              validation_loss_values))

                if (minibatches_processed > minimum_updates
                        and minibatches_processed > best_model_minibatch_time * 2):
                    stopping_condition_met = True
                    break
    logs(best_model_loss)
    return best_model_path