import argparse
import math
import os
import random
import time

import numpy as np
import torch

# NOTE: the repo-specific helpers used below (add_general_args, add_model_args,
# add_classifier_model_args, add_run_classifier_args, add_unsupervised_data_args,
# setup_model_and_optim, train, evaluate) are assumed to be imported from the
# surrounding project at module level.


def get_data_and_args():
    parser = argparse.ArgumentParser(description='PyTorch Sentiment Discovery Classification')
    parser = add_general_args(parser)
    parser = add_model_args(parser)
    parser = add_classifier_model_args(parser)
    data_config, data_parser, run_classifier_parser, parser = add_run_classifier_args(parser)

    args = parser.parse_args()
    args.cuda = torch.cuda.is_available()
    args.shuffle = False

    # Use != rather than `is not`: identity comparison against an int literal
    # is unreliable (and raises a SyntaxWarning on recent Pythons).
    if args.seed != -1:
        torch.manual_seed(args.seed)
        if args.cuda:
            torch.cuda.manual_seed(args.seed)

    (train_data, val_data, test_data), tokenizer = data_config.apply(args)
    args.data_size = tokenizer.num_tokens
    args.padding_idx = tokenizer.command_name_map['pad'].Id
    return (train_data, val_data, test_data), tokenizer, args
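
# Illustrative usage sketch (not part of the original script): how a
# downstream classification entry point might consume get_data_and_args().
# `build_classifier` is a hypothetical stand-in for the project's model
# constructor; only the tuple unpacking and the args fields set above are
# taken from the code itself.
#
#     (train_data, val_data, test_data), tokenizer, args = get_data_and_args()
#     model = build_classifier(args)        # hypothetical helper
#     if args.cuda:
#         model = model.cuda()
#     # args.padding_idx can be passed to the loss so pad tokens are ignored,
#     # e.g. torch.nn.CrossEntropyLoss(ignore_index=args.padding_idx)
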
def main():
    parser = argparse.ArgumentParser(description='PyTorch Sentiment-Discovery Language Modeling')
    parser = add_general_args(parser)
    parser = add_model_args(parser)
    data_config, data_parser = add_unsupervised_data_args(parser)
    args = parser.parse_args()

    torch.backends.cudnn.enabled = False
    args.cuda = torch.cuda.is_available()

    if args.multinode_init:
        args.rank = int(os.getenv('RANK', 0))
        args.world_size = int(os.getenv('WORLD_SIZE', 1))

    # initialize distributed process group and set device
    if args.rank > 0:
        torch.cuda.set_device(args.rank % torch.cuda.device_count())
    if args.world_size > 1:
        init_method = 'tcp://'
        if not args.multinode_init:
            init_method += 'localhost:6000'
        else:
            master_ip = os.getenv('MASTER_ADDR', 'localhost')
            master_port = os.getenv('MASTER_PORT', '6666')
            init_method += master_ip + ':' + master_port
        torch.distributed.init_process_group(
            backend=args.distributed_backend, world_size=args.world_size,
            rank=args.rank, init_method=init_method)

    # Set the random seed manually for reproducibility.
    if args.seed is not None and args.seed > 0:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if args.cuda:
            torch.cuda.manual_seed(args.seed)

    if args.loss_scale != 1 and args.dynamic_loss_scale:
        raise RuntimeError("Static loss scale and dynamic loss scale cannot be used together.")

    (train_data, val_data, test_data), tokenizer = data_config.apply(args)
    args.data_size = tokenizer.num_tokens
    model, optim, LR, LR_Warmer, criterion = setup_model_and_optim(args, train_data, tokenizer)

    lr = args.lr
    best_val_loss = None

    # If saving intermittently, create the directory to save into.
    if args.save_iters > 0 and not os.path.exists(os.path.splitext(args.save)[0]) and args.rank < 1:
        os.makedirs(os.path.splitext(args.save)[0])

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        total_iters = 0
        elapsed_time = 0
        skipped_iters = 0
        if args.load_optim:
            # optim_sd is presumably the optimizer state restored by
            # setup_model_and_optim when args.load_optim is set.
            total_iters = optim_sd['iter']
            skipped_iters = optim_sd['skipped_iter']
        for epoch in range(1, args.epochs + 1):
            if args.rank <= 0:
                # Touch a lock file so other ranks know training is in progress.
                with open(args.save + '.train_lock', 'wb') as f:
                    pass
            epoch_start_time = time.time()
            val_loss, skipped_iters = train(epoch, model, optim, train_data, LR, LR_Warmer,
                                            criterion, args, total_iters, skipped_iters,
                                            elapsed_time)
            elapsed_time += time.time() - epoch_start_time
            total_iters += args.train_iters
            if val_data is not None:
                print('entering eval')
                val_loss = evaluate(val_data, model, criterion, args)
                print('-' * 89)
                print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.4f} | '
                      'valid ppl {:8.4f}'.format(epoch, (time.time() - epoch_start_time),
                                                 val_loss, math.exp(min(val_loss, 20))))
                print('-' * 89)
            # Save the model if the validation loss is the best we've seen so far.
            # Compare against None explicitly so a best loss of 0.0 is not treated as unset.
            if (best_val_loss is None or val_loss < best_val_loss) and args.rank <= 0:
                torch.save(model.state_dict(), args.save)
                best_val_loss = val_loss
            if args.world_size == 1 or torch.distributed.get_rank() == 0:
                try:
                    os.remove(args.save + '.train_lock')
                except OSError:
                    # Another rank may already have removed the lock file.
                    pass
            # if args.world_size > 1:
            #     torch.distributed.barrier()
            torch.cuda.synchronize()
    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')

    # while os.path.exists(args.save + '.train_lock'):
    #     time.sleep(1)

    # Load the best saved model.
    # if os.path.exists(args.save):
    #     model.load_state_dict(torch.load(args.save, 'cpu'))
    #     if not args.no_weight_norm and args.rank <= 0:
    #         remove_weight_norm(model)
    #         torch.save(model.state_dict(), args.save)

    if test_data is not None:
        # Run on test data.
        print('entering test')
        test_loss = evaluate(test_data, model, criterion, args)
        print('=' * 89)
        print('| End of training | test loss {:5.4f} | test ppl {:8.4f}'.format(
            test_loss, math.exp(min(test_loss, 20))))
        print('=' * 89)
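
# Standard entry-point guard (assumed to close the original file) so that
# running the module directly, e.g. `python main.py`, kicks off training.
if __name__ == '__main__':
    main()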