def run(proc_id, n_gpus, devices, args):
    """Run one worker: build data and model, then train or just evaluate.

    When ``args.load_model`` is set, the saved model is loaded into the
    predictor, ``pred_sent`` and the tester's ``final_evaluate`` are run,
    and the function returns without training. Otherwise an
    ``LSTMClassifier`` is built and trained; afterwards the worker with
    ``proc_id == 0`` reloads ``args.save_model_fname`` and calls
    ``bookkeep``.

    Args:
        proc_id: Index of this worker; used to pick its entry in
            ``devices`` and as its rank in the process group.
        n_gpus: Number of workers. A process group is only initialised
            when this is greater than 1.
        devices: Sequence of device ids, indexed by ``proc_id``.
        args: Namespace-like object holding the run configuration (seed,
            file names, model hyper-parameters, training settings).

    Returns:
        None.
    """
    set_seed(args.seed)
    dev_id = devices[proc_id]

    if n_gpus > 1:
        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
            master_ip='127.0.0.1', master_port=args.tcp_port)
        world_size = n_gpus
        # The rank must lie in [0, world_size); use the worker index rather
        # than the physical device id, which may be any number (e.g. when
        # devices == [2, 3]).
        torch.distributed.init_process_group(backend="nccl",
                                             init_method=dist_init_method,
                                             world_size=world_size,
                                             rank=proc_id)
    device = torch.device(dev_id)

    dataset = Dataset(
        proc_id=proc_id,
        data_dir=args.save_dir,
        train_fname=args.train_fname,
        preprocessed=args.preprocessed,
        lower=args.lower,
        vocab_max_size=args.vocab_max_size,
        emb_dim=args.emb_dim,
        save_vocab_fname=args.save_vocab_fname,
        verbose=True,
    )
    train_dl, valid_dl, test_dl = dataset.get_dataloader(
        proc_id=proc_id, n_gpus=n_gpus, device=device,
        batch_size=args.batch_size)

    validator = Validator(dataloader=valid_dl,
                          save_dir=args.save_dir,
                          save_log_fname=args.save_log_fname,
                          save_model_fname=args.save_model_fname,
                          valid_or_test='valid',
                          vocab_itos=dataset.INPUT.vocab.itos,
                          label_itos=dataset.TGT.vocab.itos)
    # Unlike the validator, the tester is not given save_model_fname.
    tester = Validator(dataloader=test_dl,
                       save_log_fname=args.save_log_fname,
                       save_dir=args.save_dir,
                       valid_or_test='test',
                       vocab_itos=dataset.INPUT.vocab.itos,
                       label_itos=dataset.TGT.vocab.itos)
    predictor = Predictor(args.save_vocab_fname)

    if args.load_model:
        # Evaluation-only path: use the saved model and skip training.
        predictor.use_pretrained_model(args.load_model, device=device)
        predictor.pred_sent(dataset.INPUT)
        tester.final_evaluate(predictor.model)
        return

    model = LSTMClassifier(emb_vectors=dataset.INPUT.vocab.vectors,
                           emb_dropout=args.emb_dropout,
                           lstm_dim=args.lstm_dim,
                           lstm_n_layer=args.lstm_n_layer,
                           lstm_dropout=args.lstm_dropout,
                           lstm_combine=args.lstm_combine,
                           linear_dropout=args.linear_dropout,
                           n_linear=args.n_linear,
                           n_classes=len(dataset.TGT.vocab))
    if args.init_xavier:
        model.apply(init_weights)
    model = model.to(device)
    # model_setup returns the args object used from here on.
    args = model_setup(proc_id, model, args)

    train(proc_id, n_gpus, model=model, train_dl=train_dl,
          validator=validator, tester=tester, epochs=args.epochs,
          lr=args.lr, weight_decay=args.weight_decay)

    if proc_id == 0:
        # Only the first worker reloads the saved model and does bookkeeping.
        predictor.use_pretrained_model(args.save_model_fname, device=device)
        bookkeep(predictor, validator, tester, args, dataset.INPUT)