def baseline(args):
    args = parse_arguments(args)

    if not os.path.isfile(args.model_path):
        raise Exception('Model path is invalid')

    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_helper = TrainHelper(device)
    train_helper.seed_torch(seed=args.seed)

    checkpoint = torch.load(args.model_path)

    args.sender_path = None
    args.receiver_path = None

    # get sender and receiver models and save them
    sender, _, _ = get_sender_receiver(device, args)
    sender.load_state_dict(checkpoint['sender'])
    sender.greedy = True

    model = get_trainer(sender, device, dataset_type="raw")
    model.visual_module.load_state_dict(checkpoint['visual_module'])

    train_data, validation_data, test_data = get_shapes_dataloader(
        device=device,
        batch_size=1024,
        k=3,
        debug=False,
        dataset_type="raw")

    model.to(device)
    model.eval()

    if not os.path.exists(args.output_path):
        os.mkdir(args.output_path)

    sample_messages_from_dataset(model, args, train_data, 'train')
    sample_messages_from_dataset(model, args, validation_data, 'validation')
    sample_messages_from_dataset(model, args, test_data, 'test')
def baseline(args):
    args = parse_arguments(args)

    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_helper = TrainHelper(device)
    train_helper.seed_torch(seed=args.seed)

    model = ImageReceiver()
    model.to(device)

    unique_name = generate_unique_name(
        length=args.max_length,
        vocabulary_size=args.vocab_size,
        seed=args.messages_seed)

    model_name = generate_model_name(
        length=args.max_length,
        vocabulary_size=args.vocab_size,
        messages_seed=args.messages_seed,
        training_seed=args.seed)

    if not os.path.exists('results'):
        os.mkdir('results')

    reproducing_folder = os.path.join('results', 'reproducing-images')
    if not os.path.exists(reproducing_folder):
        os.mkdir(reproducing_folder)

    output_path = os.path.join(reproducing_folder, model_name)
    if not os.path.exists(output_path):
        os.mkdir(output_path)

    model_path = os.path.join(output_path, 'messages_receiver.p')

    print(f'loading messages using unique name: "{unique_name}"')

    train_dataset = MessageDataset(unique_name, DatasetType.Train)
    train_dataloader = data.DataLoader(
        train_dataset,
        num_workers=1,
        pin_memory=True,
        shuffle=True,
        batch_size=args.batch_size)

    validation_dataset = MessageDataset(unique_name, DatasetType.Valid)
    validation_dataloader = data.DataLoader(
        validation_dataset,
        num_workers=1,
        pin_memory=True,
        shuffle=False,
        batch_size=args.batch_size)

    test_dataset = MessageDataset(unique_name, DatasetType.Test)
    test_dataloader = data.DataLoader(
        test_dataset,
        num_workers=1,
        pin_memory=True,
        shuffle=False,
        batch_size=args.batch_size)

    pytorch_total_params = sum(p.numel() for p in model.parameters())

    # Print info
    print("----------------------------------------")
    print(model)
    print("Total number of parameters: {}".format(pytorch_total_params))

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    epoch = 0
    current_patience = args.patience
    best_accuracy = -1.
    converged = False
    start_time = time.time()

    criterion = torch.nn.MSELoss()

    images_path = os.path.join(output_path, 'samples')
    if not os.path.exists(images_path):
        os.mkdir(images_path)

    iteration = load_model(model_path, model, optimizer)

    if args.test_mode:
        test_images_path = os.path.join(images_path, 'test')
        if not os.path.exists(test_images_path):
            os.mkdir(test_images_path)

        test_loss_meter, test_acc_meter = evaluate(
            model, criterion, test_dataloader, 0, device, test_images_path)
        print(f'TEST results: loss: {test_loss_meter.avg} | accuracy: {test_acc_meter.avg}')
        return

    print(header)

    while iteration < args.iterations:
        for (messages, original_targets) in train_dataloader:
            print(f'{iteration}/{args.iterations} \r', end='')

            model.train()
            _, _ = perform_iteration(
                model, criterion, optimizer, messages, original_targets,
                iteration, device, images_path, False)

            if iteration % args.log_interval == 0:
                valid_loss_meter, valid_acc_meter = evaluate(
                    model, criterion, validation_dataloader, iteration, device, images_path)

                new_best = False
                average_valid_accuracy = valid_acc_meter.avg

                if average_valid_accuracy < best_accuracy:
                    # No new best found: spend patience, stop once it runs out
                    current_patience -= 1
                    if current_patience <= 0:
                        print('Model has converged. Stopping training...')
                        converged = True
                        break
                else:
                    # New best found: save the model and reset patience
                    new_best = True
                    best_accuracy = average_valid_accuracy
                    current_patience = args.patience
                    save_model(model_path, model, optimizer, iteration)

                print(log_template.format(
                    time.time() - start_time,
                    epoch,
                    iteration,
                    1 + iteration,
                    args.iterations,
                    100. * (1 + iteration) / args.iterations,
                    valid_loss_meter.avg,
                    valid_acc_meter.avg,
                    "BEST" if new_best else ""))

            iteration += 1

        epoch += 1

        if converged:
            break
def baseline(args):
    args = parse_arguments(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    file_helper = FileHelper()
    train_helper = TrainHelper(device)
    train_helper.seed_torch(seed=args.seed)

    model_name = train_helper.get_filename_from_baseline_params(args)
    run_folder = file_helper.get_run_folder(args.folder, model_name)

    metrics_helper = MetricsHelper(run_folder, args.seed)

    # get sender and receiver models and save them
    sender, receiver, diagnostic_receiver = get_sender_receiver(device, args)

    sender_file = file_helper.get_sender_path(run_folder)
    receiver_file = file_helper.get_receiver_path(run_folder)
    # torch.save(sender, sender_file)

    if receiver:
        torch.save(receiver, receiver_file)

    model = get_trainer(
        sender,
        device,
        args.dataset_type,
        receiver=receiver,
        diagnostic_receiver=diagnostic_receiver,
        vqvae=args.vqvae,
        rl=args.rl,
        entropy_coefficient=args.entropy_coefficient,
        myopic=args.myopic,
        myopic_coefficient=args.myopic_coefficient,
    )

    model_path = file_helper.create_unique_model_path(model_name)

    best_accuracy = -1.0
    epoch = 0
    iteration = 0

    if args.resume_training or args.test_mode:
        epoch, iteration, best_accuracy = load_model_state(model, model_path)
        print(
            f"Loaded model. Resuming from - epoch: {epoch} | iteration: {iteration} | best accuracy: {best_accuracy}"
        )

    if not os.path.exists(file_helper.model_checkpoint_path):
        print("No checkpoint exists. Saving model...\r")
        torch.save(model.visual_module, file_helper.model_checkpoint_path)
        print("No checkpoint exists. Saving model...Done")

    train_data, valid_data, test_data, valid_meta_data, _ = get_training_data(
        device=device,
        batch_size=args.batch_size,
        k=args.k,
        debugging=args.debugging,
        dataset_type=args.dataset_type,
    )

    train_meta_data, valid_meta_data, test_meta_data = get_meta_data()

    # dump arguments
    pickle.dump(args, open(f"{run_folder}/experiment_params.p", "wb"))

    pytorch_total_params = sum(p.numel() for p in model.parameters())

    if not args.disable_print:
        # Print info
        print("----------------------------------------")
        print(
            "Model name: {} \n|V|: {}\nL: {}".format(
                model_name, args.vocab_size, args.max_length
            )
        )
        print(sender)
        if receiver:
            print(receiver)
        if diagnostic_receiver:
            print(diagnostic_receiver)
        print("Total number of parameters: {}".format(pytorch_total_params))

    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Train
    current_patience = args.patience
    best_accuracy = -1.0
    converged = False
    start_time = time.time()

    if args.test_mode:
        test_loss_meter, test_acc_meter, _ = train_helper.evaluate(
            model, test_data, test_meta_data, device, args.rl
        )
        average_test_accuracy = test_acc_meter.avg
        average_test_loss = test_loss_meter.avg
        print(
            f"TEST results: loss: {average_test_loss} | accuracy: {average_test_accuracy}"
        )
        return

    iterations = []
    losses = []
    hinge_losses = []
    rl_losses = []
    entropies = []
    accuracies = []

    while iteration < args.iterations:
        for train_batch in train_data:
            print(f"{iteration}/{args.iterations} \r", end="")

            # !!! This is the complete training procedure. The rest is only logging!
            _, _ = train_helper.train_one_batch(
                model, train_batch, optimizer, train_meta_data, device
            )

            if iteration % args.log_interval == 0:
                if not args.rl:
                    valid_loss_meter, valid_acc_meter, _ = train_helper.evaluate(
                        model, valid_data, valid_meta_data, device, args.rl
                    )
                else:
                    valid_loss_meter, hinge_loss_meter, rl_loss_meter, entropy_meter, valid_acc_meter, _ = train_helper.evaluate(
                        model, valid_data, valid_meta_data, device, args.rl
                    )

                new_best = False
                average_valid_accuracy = valid_acc_meter.avg

                if average_valid_accuracy < best_accuracy:
                    # No new best found. May lead to early stopping
                    current_patience -= 1
                    if current_patience <= 0:
                        print("Model has converged. Stopping training...")
                        converged = True
                        break
                else:
                    # New best found. Save it.
                    new_best = True
                    best_accuracy = average_valid_accuracy
                    current_patience = args.patience
                    save_model_state(model, model_path, epoch, iteration, best_accuracy)

                if not args.disable_print:
                    if not args.rl:
                        print(
                            "{}/{} Iterations: val loss: {}, val accuracy: {}".format(
                                iteration,
                                args.iterations,
                                valid_loss_meter.avg,
                                valid_acc_meter.avg,
                            )
                        )
                    else:
                        print(
                            "{}/{} Iterations: val loss: {}, val hinge loss: {}, val rl loss: {}, val entropy: {}, val accuracy: {}".format(
                                iteration,
                                args.iterations,
                                valid_loss_meter.avg,
                                hinge_loss_meter.avg,
                                rl_loss_meter.avg,
                                entropy_meter.avg,
                                valid_acc_meter.avg,
                            )
                        )

                iterations.append(iteration)
                losses.append(valid_loss_meter.avg)
                if args.rl:
                    hinge_losses.append(hinge_loss_meter.avg)
                    rl_losses.append(rl_loss_meter.avg)
                    entropies.append(entropy_meter.avg)
                accuracies.append(valid_acc_meter.avg)

            iteration += 1
            if iteration >= args.iterations:
                break

        epoch += 1
        if converged:
            break

    # prepare writing of data
    dir_path = os.path.dirname(os.path.realpath(__file__))
    dir_path = dir_path.replace("/baseline", "")
    timestamp = str(datetime.datetime.now())
    filename = "output_data/vqvae_{}_rl_{}_dc_{}_gs_{}_dln_{}_dld_{}_beta_{}_entropy_coefficient_{}_myopic_{}_mc_{}_seed_{}_{}.csv".format(
        args.vqvae,
        args.rl,
        args.discrete_communication,
        args.gumbel_softmax,
        args.discrete_latent_number,
        args.discrete_latent_dimension,
        args.beta,
        args.entropy_coefficient,
        args.myopic,
        args.myopic_coefficient,
        args.seed,
        timestamp,
    )
    full_filename = os.path.join(dir_path, filename)

    # write data
    d = [iterations, losses, hinge_losses, rl_losses, entropies, accuracies]
    export_data = zip_longest(*d, fillvalue="")

    with open(full_filename, "w", encoding="ISO-8859-1", newline="") as myfile:
        wr = csv.writer(myfile)
        wr.writerow(
            ("iteration", "loss", "hinge loss", "rl loss", "entropy", "accuracy")
        )
        wr.writerows(export_data)

    # plotting
    print(filename)
    plot_data(filename, args)

    return run_folder
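# Hedged sketch (not part of the original script): the CSV exported above uses
# the header ("iteration", "loss", "hinge loss", "rl loss", "entropy",
# "accuracy"); plot_data() is the repository's own plotting helper, so this
# standard-library reader is only an assumed illustration of that file layout.
def read_exported_metrics(csv_path):
    import csv

    with open(csv_path, "r", encoding="ISO-8859-1", newline="") as f:
        reader = csv.DictReader(f)
        # Empty strings appear where zip_longest() padded the shorter
        # metric columns (e.g. the RL-only meters in non-RL runs).
        return list(reader)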
def baseline(args):
    args = parse_arguments(args)

    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inference_path = os.path.join('data', 'inference')
    if not os.path.exists(inference_path):
        os.mkdir(inference_path)

    unique_name = generate_unique_name(
        length=args.max_length,
        vocabulary_size=args.vocab_size,
        seed=args.messages_seed)

    model_name = generate_model_name(
        length=args.max_length,
        vocabulary_size=args.vocab_size,
        messages_seed=args.messages_seed,
        training_seed=args.training_seed)

    model_path = os.path.join(inference_path, model_name)

    train_helper = TrainHelper(device)
    train_helper.seed_torch(args.training_seed)

    train_dataset = DiagnosticDataset(unique_name, DatasetType.Train)
    train_dataloader = data.DataLoader(
        train_dataset,
        num_workers=1,
        pin_memory=True,
        shuffle=True,
        batch_size=args.batch_size)

    validation_dataset = DiagnosticDataset(unique_name, DatasetType.Valid)
    validation_dataloader = data.DataLoader(
        validation_dataset,
        num_workers=1,
        pin_memory=True,
        shuffle=False,
        batch_size=args.batch_size)

    test_dataset = DiagnosticDataset(unique_name, DatasetType.Test)
    test_dataloader = data.DataLoader(
        test_dataset,
        num_workers=1,
        pin_memory=True,
        shuffle=False,
        batch_size=args.batch_size)

    diagnostic_model = initialize_model(args, device)
    diagnostic_model.to(device)

    print(header)

    best_accuracy = -1.
    for epoch in range(args.max_epochs):
        # TRAIN
        diagnostic_model.train()
        perform_iteration(
            diagnostic_model, train_dataloader, args.batch_size, device,
            sample_count=args.training_sample_count)

        # VALIDATION
        diagnostic_model.eval()
        validation_accuracies_meter, validation_losses_meter = perform_iteration(
            diagnostic_model, validation_dataloader, args.batch_size, device)

        new_best = False
        if validation_accuracies_meter.avg > best_accuracy:
            new_best = True
            best_accuracy = validation_accuracies_meter.avg
            save_model(model_path, diagnostic_model)

        print_results(
            validation_accuracies_meter, validation_losses_meter,
            epoch, args.max_epochs, "validation", new_best)

    # TEST with the best saved model
    best_model = initialize_model(args, device, model_path)
    best_model.eval()
    test_accuracies_meter, test_losses_meter = perform_iteration(
        best_model, test_dataloader, args.batch_size, device)
    print_results(
        test_accuracies_meter, test_losses_meter,
        epoch, args.max_epochs, "test", False)
def baseline(args):
    args = parse_arguments(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    file_helper = FileHelper()
    train_helper = TrainHelper(device)
    train_helper.seed_torch(seed=args.seed)

    model_name = train_helper.get_filename_from_baseline_params(args)
    run_folder = file_helper.get_run_folder(args.folder, model_name)

    logger = Logger(run_folder, print_logs=(not args.disable_print))
    logger.log_args(args)

    # get sender and receiver models and save them
    sender, receiver, diagnostic_receiver = get_sender_receiver(device, args)

    sender_file = file_helper.get_sender_path(run_folder)
    receiver_file = file_helper.get_receiver_path(run_folder)
    # torch.save(sender, sender_file)

    if receiver:
        torch.save(receiver, receiver_file)

    model = get_trainer(
        sender,
        device,
        args.dataset_type,
        receiver=receiver,
        diagnostic_receiver=diagnostic_receiver,
        vqvae=args.vqvae,
        rl=args.rl,
        entropy_coefficient=args.entropy_coefficient,
        myopic=args.myopic,
        myopic_coefficient=args.myopic_coefficient,
    )

    model_path = file_helper.create_unique_model_path(model_name)

    best_accuracy = -1.0
    epoch = 0
    iteration = 0

    if args.resume_training or args.test_mode:
        epoch, iteration, best_accuracy = load_model_state(model, model_path)
        print(
            f"Loaded model. Resuming from - epoch: {epoch} | iteration: {iteration} | best accuracy: {best_accuracy}"
        )

    if not os.path.exists(file_helper.model_checkpoint_path):
        print("No checkpoint exists. Saving model...\r")
        torch.save(model.visual_module, file_helper.model_checkpoint_path)
        print("No checkpoint exists. Saving model...Done")

    train_data, valid_data, test_data, valid_meta_data, _ = get_training_data(
        device=device,
        batch_size=args.batch_size,
        k=args.k,
        debugging=args.debugging,
        dataset_type=args.dataset_type,
    )

    train_meta_data, valid_meta_data, test_meta_data = get_meta_data()

    pytorch_total_params = sum(p.numel() for p in model.parameters())

    if not args.disable_print:
        # Print info
        print("----------------------------------------")
        print("Model name: {} \n|V|: {}\nL: {}".format(
            model_name, args.vocab_size, args.max_length))
        print(sender)
        if receiver:
            print(receiver)
        if diagnostic_receiver:
            print(diagnostic_receiver)
        print("Total number of parameters: {}".format(pytorch_total_params))

    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Train
    current_patience = args.patience
    best_accuracy = -1.0
    converged = False
    start_time = time.time()

    if args.test_mode:
        test_loss_meter, test_acc_meter, _ = train_helper.evaluate(
            model, test_data, test_meta_data, device, args.rl)
        average_test_accuracy = test_acc_meter.avg
        average_test_loss = test_loss_meter.avg
        print(
            f"TEST results: loss: {average_test_loss} | accuracy: {average_test_accuracy}"
        )
        return

    while iteration < args.iterations:
        for train_batch in train_data:
            print(f"{iteration}/{args.iterations} \r", end="")

            # !!! This is the complete training procedure. The rest is only logging!
            _, _ = train_helper.train_one_batch(
                model, train_batch, optimizer, train_meta_data, device)

            if iteration % args.log_interval == 0:
                if not args.rl:
                    valid_loss_meter, valid_acc_meter, _ = train_helper.evaluate(
                        model, valid_data, valid_meta_data, device, args.rl)
                else:
                    valid_loss_meter, hinge_loss_meter, rl_loss_meter, entropy_meter, valid_acc_meter, _ = train_helper.evaluate(
                        model, valid_data, valid_meta_data, device, args.rl)

                new_best = False
                average_valid_accuracy = valid_acc_meter.avg

                if average_valid_accuracy < best_accuracy:
                    # No new best found. May lead to early stopping
                    current_patience -= 1
                    if current_patience <= 0:
                        print("Model has converged. Stopping training...")
                        converged = True
                        break
                else:
                    # New best found. Save it.
                    new_best = True
                    best_accuracy = average_valid_accuracy
                    current_patience = args.patience
                    save_model_state(model, model_path, epoch, iteration, best_accuracy)

                metrics = {
                    'loss': valid_loss_meter.avg,
                    'accuracy': valid_acc_meter.avg,
                }
                if args.rl:
                    metrics['hinge loss'] = hinge_loss_meter.avg
                    metrics['rl loss'] = rl_loss_meter.avg
                    metrics['entropy'] = entropy_meter.avg
                logger.log_metrics(iteration, metrics)

            iteration += 1
            if iteration >= args.iterations:
                break

        epoch += 1
        if converged:
            break

    return run_folder
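# Hedged usage sketch (an assumption, not confirmed by these files): each
# baseline() above takes a raw argument list and calls parse_arguments()
# itself, so a minimal command-line entry point for any of these scripts
# could look like the following.
if __name__ == "__main__":
    import sys

    baseline(sys.argv[1:])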