Example #1
def baseline(args):
    args = parse_arguments(args)

    if not os.path.isfile(args.model_path):
        raise FileNotFoundError(f'Model path is invalid: {args.model_path}')

    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_helper = TrainHelper(device)
    train_helper.seed_torch(seed=args.seed)

    checkpoint = torch.load(args.model_path, map_location=device)

    args.sender_path = None
    args.receiver_path = None
    # build the sender model and restore its trained weights
    sender, _, _ = get_sender_receiver(device, args)
    sender.load_state_dict(checkpoint['sender'])
    sender.greedy = True

    model = get_trainer(
        sender,
        device,
        dataset_type="raw")

    model.visual_module.load_state_dict(checkpoint['visual_module'])

    train_data, validation_data, test_data = get_shapes_dataloader(
        device=device,
        batch_size=1024,
        k=3,
        debug=False,
        dataset_type="raw")

    model.to(device)

    model.eval()

    if not os.path.exists(args.output_path):
        os.mkdir(args.output_path)

    sample_messages_from_dataset(model, args, train_data, 'train')
    sample_messages_from_dataset(model, args, validation_data, 'validation')
    sample_messages_from_dataset(model, args, test_data, 'test')
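The device-selection logic at the top of this example is repeated verbatim in the later ones. A minimal standalone sketch of the same pattern (the helper name resolve_device is illustrative and not part of the repository):

import torch

def resolve_device(requested=None):
    # Honour an explicit request such as "cpu" or "cuda:1"; otherwise prefer CUDA when available.
    if requested:
        return torch.device(requested)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# usage: device = resolve_device(args.device)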
Example #2
def baseline(args):

    args = parse_arguments(args)

    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_helper = TrainHelper(device)
    train_helper.seed_torch(seed=args.seed)

    model = ImageReceiver()
    model.to(device)

    unique_name = generate_unique_name(
        length=args.max_length,
        vocabulary_size=args.vocab_size,
        seed=args.messages_seed)

    model_name = generate_model_name(
        length=args.max_length,
        vocabulary_size=args.vocab_size,
        messages_seed=args.messages_seed,
        training_seed=args.seed)

    if not os.path.exists('results'):
        os.mkdir('results')

    reproducing_folder = os.path.join('results', 'reproducing-images')
    if not os.path.exists(reproducing_folder):
        os.mkdir(reproducing_folder)

    output_path = os.path.join(reproducing_folder, model_name)
    if not os.path.exists(output_path):
        os.mkdir(output_path)

    model_path = os.path.join(output_path, 'messages_receiver.p')

    print(f'loading messages using unique name: "{unique_name}"')

    train_dataset = MessageDataset(unique_name, DatasetType.Train)
    train_dataloader = data.DataLoader(train_dataset, num_workers=1, pin_memory=True, shuffle=True, batch_size=args.batch_size)

    validation_dataset = MessageDataset(unique_name, DatasetType.Valid)
    validation_dataloader = data.DataLoader(validation_dataset, num_workers=1, pin_memory=True, shuffle=False, batch_size=args.batch_size)

    test_dataset = MessageDataset(unique_name, DatasetType.Test)
    test_dataloader = data.DataLoader(test_dataset, num_workers=1, pin_memory=True, shuffle=False, batch_size=args.batch_size)

    pytorch_total_params = sum(p.numel() for p in model.parameters())

    # Print info
    print("----------------------------------------")
    print(model)
    print("Total number of parameters: {}".format(pytorch_total_params))

    epoch = 0
    current_patience = args.patience
    best_accuracy = -1.
    converged = False

    start_time = time.time()

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    images_path = os.path.join(output_path, 'samples')
    if not os.path.exists(images_path):
        os.mkdir(images_path)

    iteration = load_model(model_path, model, optimizer)

    if args.test_mode:
        test_images_path = os.path.join(images_path, 'test')
        if not os.path.exists(test_images_path):
            os.mkdir(test_images_path)

        test_loss_meter, test_acc_meter = evaluate(model, criterion, test_dataloader, 0, device, test_images_path)
        print(f'TEST results: loss: {test_loss_meter.avg} | accuracy: {test_acc_meter.avg}')
        return


    print(header)
    while iteration < args.iterations:
        for (messages, original_targets) in train_dataloader:
            print(f'{iteration}/{args.iterations}       \r', end='')

            model.train()

            _, _ = perform_iteration(model, criterion, optimizer, messages, original_targets, iteration, device, images_path, False)

            if iteration % args.log_interval == 0:
                valid_loss_meter, valid_acc_meter = evaluate(model, criterion, validation_dataloader, iteration, device, images_path)

                new_best = False
                average_valid_accuracy = valid_acc_meter.avg

                if average_valid_accuracy < best_accuracy:
                    current_patience -= 1

                    if current_patience <= 0:
                        print('Model has converged. Stopping training...')
                        converged = True
                        break
                else:
                    new_best = True
                    best_accuracy = average_valid_accuracy
                    current_patience = args.patience
                    save_model(model_path, model, optimizer, iteration)


                print(log_template.format(
                    time.time()-start_time,
                    epoch,
                    iteration,
                    1 + iteration,
                    args.iterations,
                    100. * (1+iteration) / args.iterations,
                    valid_loss_meter.avg,
                    valid_acc_meter.avg,
                    "BEST" if new_best else ""))

            iteration += 1

        epoch += 1

        if converged:
            break
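The loop above implements patience-based early stopping: every log_interval iterations the model is evaluated, a new best validation accuracy resets the patience counter and triggers a checkpoint save, and patience consecutive non-improving evaluations end training. A minimal sketch of the same bookkeeping in isolation (names are illustrative, not from the repository):

def should_stop(accuracy_history, patience):
    # Stop once the last `patience` evaluations are all below the best accuracy seen before them.
    if len(accuracy_history) <= patience:
        return False
    best_so_far = max(accuracy_history[:-patience])
    return all(acc < best_so_far for acc in accuracy_history[-patience:])

# usage: should_stop([0.42, 0.61, 0.58, 0.57, 0.55], patience=3) -> True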
Example #3
def baseline(args):
    args = parse_arguments(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    file_helper = FileHelper()
    train_helper = TrainHelper(device)
    train_helper.seed_torch(seed=args.seed)

    model_name = train_helper.get_filename_from_baseline_params(args)
    run_folder = file_helper.get_run_folder(args.folder, model_name)

    metrics_helper = MetricsHelper(run_folder, args.seed)

    # get sender and receiver models and save them
    sender, receiver, diagnostic_receiver = get_sender_receiver(device, args)

    sender_file = file_helper.get_sender_path(run_folder)
    receiver_file = file_helper.get_receiver_path(run_folder)
    # torch.save(sender, sender_file)

    if receiver:
        torch.save(receiver, receiver_file)

    model = get_trainer(
        sender,
        device,
        args.dataset_type,
        receiver=receiver,
        diagnostic_receiver=diagnostic_receiver,
        vqvae=args.vqvae,
        rl=args.rl,
        entropy_coefficient=args.entropy_coefficient,
        myopic=args.myopic,
        myopic_coefficient=args.myopic_coefficient,
    )

    model_path = file_helper.create_unique_model_path(model_name)

    best_accuracy = -1.0
    epoch = 0
    iteration = 0

    if args.resume_training or args.test_mode:
        epoch, iteration, best_accuracy = load_model_state(model, model_path)
        print(
            f"Loaded model. Resuming from - epoch: {epoch} | iteration: {iteration} | best accuracy: {best_accuracy}"
        )

    if not os.path.exists(file_helper.model_checkpoint_path):
        print("No checkpoint exists. Saving model...\r")
        torch.save(model.visual_module, file_helper.model_checkpoint_path)
        print("No checkpoint exists. Saving model...Done")

    train_data, valid_data, test_data, valid_meta_data, _ = get_training_data(
        device=device,
        batch_size=args.batch_size,
        k=args.k,
        debugging=args.debugging,
        dataset_type=args.dataset_type,
    )

    train_meta_data, valid_meta_data, test_meta_data = get_meta_data()

    # dump arguments
    pickle.dump(args, open(f"{run_folder}/experiment_params.p", "wb"))

    pytorch_total_params = sum(p.numel() for p in model.parameters())

    if not args.disable_print:
        # Print info
        print("----------------------------------------")
        print(
            "Model name: {} \n|V|: {}\nL: {}".format(
                model_name, args.vocab_size, args.max_length
            )
        )
        print(sender)
        if receiver:
            print(receiver)

        if diagnostic_receiver:
            print(diagnostic_receiver)

        print("Total number of parameters: {}".format(pytorch_total_params))

    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Train (best_accuracy keeps the value initialised above or restored from the checkpoint)
    current_patience = args.patience
    converged = False

    start_time = time.time()

    if args.test_mode:
        test_loss_meter, test_acc_meter, _ = train_helper.evaluate(
            model, test_data, test_meta_data, device, args.rl
        )

        average_test_accuracy = test_acc_meter.avg
        average_test_loss = test_loss_meter.avg

        print(
            f"TEST results: loss: {average_test_loss} | accuracy: {average_test_accuracy}"
        )
        return

    iterations = []
    losses = []
    hinge_losses = []
    rl_losses = []
    entropies = []
    accuracies = []

    while iteration < args.iterations:
        for train_batch in train_data:
            print(f"{iteration}/{args.iterations}       \r", end="")

            ### !!! This is the complete training procedure. Rest is only logging!
            _, _ = train_helper.train_one_batch(
                model, train_batch, optimizer, train_meta_data, device
            )

            if iteration % args.log_interval == 0:

                if not args.rl:
                    valid_loss_meter, valid_acc_meter, _ = train_helper.evaluate(
                        model, valid_data, valid_meta_data, device, args.rl
                    )
                else:
                    valid_loss_meter, hinge_loss_meter, rl_loss_meter, entropy_meter, valid_acc_meter, _ = train_helper.evaluate(
                        model, valid_data, valid_meta_data, device, args.rl
                    )

                new_best = False

                average_valid_accuracy = valid_acc_meter.avg

                if (
                    average_valid_accuracy < best_accuracy
                ):  # No new best found. May lead to early stopping
                    current_patience -= 1

                    if current_patience <= 0:
                        print("Model has converged. Stopping training...")
                        converged = True
                        break
                else:  # new best found. Is saved.
                    new_best = True
                    best_accuracy = average_valid_accuracy
                    current_patience = args.patience
                    save_model_state(model, model_path, epoch, iteration, best_accuracy)

                # Print validation results unless printing is disabled
                if not args.disable_print:

                    if not args.rl:
                        print(
                            "{}/{} Iterations: val loss: {}, val accuracy: {}".format(
                                iteration,
                                args.iterations,
                                valid_loss_meter.avg,
                                valid_acc_meter.avg,
                            )
                        )
                    else:
                        print(
                            "{}/{} Iterations: val loss: {}, val hinge loss: {}, val rl loss: {}, val entropy: {}, val accuracy: {}".format(
                                iteration,
                                args.iterations,
                                valid_loss_meter.avg,
                                hinge_loss_meter.avg,
                                rl_loss_meter.avg,
                                entropy_meter.avg,
                                valid_acc_meter.avg,
                            )
                        )

                iterations.append(iteration)
                losses.append(valid_loss_meter.avg)
                if args.rl:
                    hinge_losses.append(hinge_loss_meter.avg)
                    rl_losses.append(rl_loss_meter.avg)
                    entropies.append(entropy_meter.avg)
                accuracies.append(valid_acc_meter.avg)

            iteration += 1
            if iteration >= args.iterations:
                break

        epoch += 1

        if converged:
            break

    # prepare writing of data
    dir_path = os.path.dirname(os.path.realpath(__file__))
    dir_path = dir_path.replace("/baseline", "")
    timestamp = str(datetime.datetime.now())
    filename = "output_data/vqvae_{}_rl_{}_dc_{}_gs_{}_dln_{}_dld_{}_beta_{}_entropy_coefficient_{}_myopic_{}_mc_{}_seed_{}_{}.csv".format(
        args.vqvae,
        args.rl,
        args.discrete_communication,
        args.gumbel_softmax,
        args.discrete_latent_number,
        args.discrete_latent_dimension,
        args.beta,
        args.entropy_coefficient,
        args.myopic,
        args.myopic_coefficient,
        args.seed,
        timestamp,
    )
    full_filename = os.path.join(dir_path, filename)

    # write data
    d = [iterations, losses, hinge_losses, rl_losses, entropies, accuracies]
    export_data = zip_longest(*d, fillvalue="")
    with open(full_filename, "w", encoding="ISO-8859-1", newline="") as myfile:
        wr = csv.writer(myfile)
        wr.writerow(
            ("iteration", "loss", "hinge loss", "rl loss", "entropy", "accuracy")
        )
        wr.writerows(export_data)

    # plotting
    print(filename)
    plot_data(filename, args)

    return run_folder
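When args.rl is false, hinge_losses, rl_losses, and entropies stay empty, which is why the export above pads the columns with zip_longest. A self-contained sketch of that export step (file name and values are made up for illustration):

import csv
from itertools import zip_longest

iterations = [0, 100, 200]
losses = [0.92, 0.71, 0.64]
entropies = []  # stays empty when RL is disabled

columns = [iterations, losses, entropies]
with open("metrics_example.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(("iteration", "loss", "entropy"))
    # zip_longest pads the shorter columns with empty strings so every row has all fields.
    writer.writerows(zip_longest(*columns, fillvalue=""))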
Example #4
def baseline(args):
    args = parse_arguments(args)

    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inference_path = os.path.join('data', 'inference')
    if not os.path.exists(inference_path):
        os.mkdir(inference_path)

    unique_name = generate_unique_name(length=args.max_length,
                                       vocabulary_size=args.vocab_size,
                                       seed=args.messages_seed)

    model_name = generate_model_name(length=args.max_length,
                                     vocabulary_size=args.vocab_size,
                                     messages_seed=args.messages_seed,
                                     training_seed=args.training_seed)

    model_path = os.path.join(inference_path, model_name)

    train_helper = TrainHelper(device)
    train_helper.seed_torch(args.training_seed)

    train_dataset = DiagnosticDataset(unique_name, DatasetType.Train)
    train_dataloader = data.DataLoader(train_dataset,
                                       num_workers=1,
                                       pin_memory=True,
                                       shuffle=True,
                                       batch_size=args.batch_size)

    validation_dataset = DiagnosticDataset(unique_name, DatasetType.Valid)
    validation_dataloader = data.DataLoader(validation_dataset,
                                            num_workers=1,
                                            pin_memory=True,
                                            shuffle=False,
                                            batch_size=args.batch_size)

    test_dataset = DiagnosticDataset(unique_name, DatasetType.Test)
    test_dataloader = data.DataLoader(test_dataset,
                                      num_workers=1,
                                      pin_memory=True,
                                      shuffle=False,
                                      batch_size=args.batch_size)

    diagnostic_model = initialize_model(args, device)

    diagnostic_model.to(device)

    print(header)

    best_accuracy = -1.

    for epoch in range(args.max_epochs):

        # TRAIN
        diagnostic_model.train()
        perform_iteration(diagnostic_model,
                          train_dataloader,
                          args.batch_size,
                          device,
                          sample_count=args.training_sample_count)

        # VALIDATION

        diagnostic_model.eval()
        validation_accuracies_meter, validation_losses_meter = perform_iteration(
            diagnostic_model, validation_dataloader, args.batch_size, device)

        new_best = False
        if validation_accuracies_meter.avg > best_accuracy:
            new_best = True
            best_accuracy = validation_accuracies_meter.avg
            save_model(model_path, diagnostic_model)

        print_results(validation_accuracies_meter, validation_losses_meter,
                      epoch, args.max_epochs, "validation", new_best)

    best_model = initialize_model(args, device, model_path)
    best_model.eval()

    test_accuracies_meter, test_losses_meter = perform_iteration(
        best_model, test_dataloader, args.batch_size, device)
    print_results(test_accuracies_meter, test_losses_meter, epoch,
                  args.max_epochs, "test", False)
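Every example starts by calling TrainHelper.seed_torch before loading data. The helper's body is not shown here; a typical seeding routine of this kind looks roughly like the sketch below (an assumption about what it does, not the repository's code):

import os
import random

import numpy as np
import torch

def seed_everything(seed):
    # Seed every RNG the training loop can touch so runs are reproducible.
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False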
Example #5
def baseline(args):
    args = parse_arguments(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    file_helper = FileHelper()
    train_helper = TrainHelper(device)
    train_helper.seed_torch(seed=args.seed)

    model_name = train_helper.get_filename_from_baseline_params(args)
    run_folder = file_helper.get_run_folder(args.folder, model_name)

    logger = Logger(run_folder, print_logs=(not args.disable_print))
    logger.log_args(args)

    # get sender and receiver models and save them
    sender, receiver, diagnostic_receiver = get_sender_receiver(device, args)

    sender_file = file_helper.get_sender_path(run_folder)
    receiver_file = file_helper.get_receiver_path(run_folder)
    # torch.save(sender, sender_file)

    if receiver:
        torch.save(receiver, receiver_file)

    model = get_trainer(
        sender,
        device,
        args.dataset_type,
        receiver=receiver,
        diagnostic_receiver=diagnostic_receiver,
        vqvae=args.vqvae,
        rl=args.rl,
        entropy_coefficient=args.entropy_coefficient,
        myopic=args.myopic,
        myopic_coefficient=args.myopic_coefficient,
    )

    model_path = file_helper.create_unique_model_path(model_name)

    best_accuracy = -1.0
    epoch = 0
    iteration = 0

    if args.resume_training or args.test_mode:
        epoch, iteration, best_accuracy = load_model_state(model, model_path)
        print(
            f"Loaded model. Resuming from - epoch: {epoch} | iteration: {iteration} | best accuracy: {best_accuracy}"
        )

    if not os.path.exists(file_helper.model_checkpoint_path):
        print("No checkpoint exists. Saving model...\r")
        torch.save(model.visual_module, file_helper.model_checkpoint_path)
        print("No checkpoint exists. Saving model...Done")

    train_data, valid_data, test_data, valid_meta_data, _ = get_training_data(
        device=device,
        batch_size=args.batch_size,
        k=args.k,
        debugging=args.debugging,
        dataset_type=args.dataset_type,
    )

    train_meta_data, valid_meta_data, test_meta_data = get_meta_data()

    pytorch_total_params = sum(p.numel() for p in model.parameters())

    if not args.disable_print:
        # Print info
        print("----------------------------------------")
        print("Model name: {} \n|V|: {}\nL: {}".format(model_name,
                                                       args.vocab_size,
                                                       args.max_length))
        print(sender)
        if receiver:
            print(receiver)

        if diagnostic_receiver:
            print(diagnostic_receiver)

        print("Total number of parameters: {}".format(pytorch_total_params))

    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Train (best_accuracy keeps the value initialised above or restored from the checkpoint)
    current_patience = args.patience
    converged = False

    start_time = time.time()

    if args.test_mode:
        test_loss_meter, test_acc_meter, _ = train_helper.evaluate(
            model, test_data, test_meta_data, device, args.rl)

        average_test_accuracy = test_acc_meter.avg
        average_test_loss = test_loss_meter.avg

        print(
            f"TEST results: loss: {average_test_loss} | accuracy: {average_test_accuracy}"
        )
        return

    while iteration < args.iterations:
        for train_batch in train_data:
            print(f"{iteration}/{args.iterations}       \r", end="")

            # !!! This is the complete training procedure. Rest is only logging!
            _, _ = train_helper.train_one_batch(model, train_batch, optimizer,
                                                train_meta_data, device)

            if iteration % args.log_interval == 0:

                if not args.rl:
                    valid_loss_meter, valid_acc_meter, _ = train_helper.evaluate(
                        model, valid_data, valid_meta_data, device, args.rl)
                else:
                    valid_loss_meter, hinge_loss_meter, rl_loss_meter, entropy_meter, valid_acc_meter, _ = train_helper.evaluate(
                        model, valid_data, valid_meta_data, device, args.rl)

                new_best = False

                average_valid_accuracy = valid_acc_meter.avg

                if average_valid_accuracy < best_accuracy:  # No new best found. May lead to early stopping
                    current_patience -= 1

                    if current_patience <= 0:
                        print("Model has converged. Stopping training...")
                        converged = True
                        break
                else:  # new best found. Is saved.
                    new_best = True
                    best_accuracy = average_valid_accuracy
                    current_patience = args.patience
                    save_model_state(model, model_path, epoch, iteration,
                                     best_accuracy)

                metrics = {
                    'loss': valid_loss_meter.avg,
                    'accuracy': valid_acc_meter.avg,
                }
                if args.rl:
                    metrics['hinge loss'] = hinge_loss_meter.avg
                    metrics['rl loss'] = rl_loss_meter.avg
                    metrics['entropy'] = entropy_meter.avg

                logger.log_metrics(iteration, metrics)

            iteration += 1
            if iteration >= args.iterations:
                break

        epoch += 1

        if converged:
            break

    return run_folder
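Example #5 differs from Example #3 mainly in routing the argument dump and validation metrics through a Logger instead of inline prints and a CSV written at the end. The repository's Logger is not shown; a minimal stand-in with the same interface used above (log_args, log_metrics) could look like this sketch:

import json
import os

class SimpleLogger:
    # Hypothetical stand-in for the Logger used above, not the repository's implementation.

    def __init__(self, run_folder, print_logs=True):
        self.run_folder = run_folder
        self.print_logs = print_logs
        os.makedirs(run_folder, exist_ok=True)
        self.metrics_file = os.path.join(run_folder, "metrics.jsonl")

    def log_args(self, args):
        # Persist the experiment configuration once, as JSON.
        with open(os.path.join(self.run_folder, "args.json"), "w") as f:
            json.dump(vars(args), f, indent=2, default=str)

    def log_metrics(self, iteration, metrics):
        # Append one JSON record per logging step; optionally echo it to stdout.
        record = {"iteration": iteration, **metrics}
        with open(self.metrics_file, "a") as f:
            f.write(json.dumps(record) + "\n")
        if self.print_logs:
            print(record)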