def run(proc_id, n_gpus, devices, args):
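    """One worker process: proc_id selects this process's GPU from `devices`."""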
    set_seed(args.seed)
    dev_id = devices[proc_id]

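    # With more than one GPU, initialize the NCCL process group over localhost TCP.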
    if n_gpus > 1:
        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
            master_ip='127.0.0.1', master_port=args.tcp_port)
        world_size = n_gpus
        torch.distributed.init_process_group(backend="nccl",
                                             init_method=dist_init_method,
                                             world_size=world_size,
                                             # rank is the process index in [0, world_size),
                                             # which need not equal the device id
                                             rank=proc_id)
    device = torch.device(dev_id)

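    # Load/preprocess the data and build per-process train/valid/test dataloaders.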
    dataset = Dataset(
        proc_id=proc_id,
        data_dir=args.save_dir,
        train_fname=args.train_fname,
        preprocessed=args.preprocessed,
        lower=args.lower,
        vocab_max_size=args.vocab_max_size,
        emb_dim=args.emb_dim,
        save_vocab_fname=args.save_vocab_fname,
        verbose=True,
    )
    train_dl, valid_dl, test_dl = \
        dataset.get_dataloader(proc_id=proc_id, n_gpus=n_gpus, device=device,
                               batch_size=args.batch_size)

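    # Validator and tester track metrics on the valid/test splits;
    # Predictor wraps inference using the saved vocabulary file.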
    validator = Validator(dataloader=valid_dl,
                          save_dir=args.save_dir,
                          save_log_fname=args.save_log_fname,
                          save_model_fname=args.save_model_fname,
                          valid_or_test='valid',
                          vocab_itos=dataset.INPUT.vocab.itos,
                          label_itos=dataset.TGT.vocab.itos)
    tester = Validator(dataloader=test_dl,
                       save_log_fname=args.save_log_fname,
                       save_dir=args.save_dir,
                       valid_or_test='test',
                       vocab_itos=dataset.INPUT.vocab.itos,
                       label_itos=dataset.TGT.vocab.itos)
    predictor = Predictor(args.save_vocab_fname)

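    # Evaluation-only path: with --load_model, restore the checkpoint, run
    # prediction and the final test evaluation, then return without training.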
    if args.load_model:
        predictor.use_pretrained_model(args.load_model, device=device)
        predictor.pred_sent(dataset.INPUT)
        tester.final_evaluate(predictor.model)

        return

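    # Build the LSTM classifier on top of the pretrained embedding vectors from the vocab.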
    model = LSTMClassifier(emb_vectors=dataset.INPUT.vocab.vectors,
                           emb_dropout=args.emb_dropout,
                           lstm_dim=args.lstm_dim,
                           lstm_n_layer=args.lstm_n_layer,
                           lstm_dropout=args.lstm_dropout,
                           lstm_combine=args.lstm_combine,
                           linear_dropout=args.linear_dropout,
                           n_linear=args.n_linear,
                           n_classes=len(dataset.TGT.vocab))
    if args.init_xavier:
        model.apply(init_weights)
    model = model.to(device)
    args = model_setup(proc_id, model, args)

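    # Run the training loop, passing in the validator and tester for evaluation.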
    train(proc_id,
          n_gpus,
          model=model,
          train_dl=train_dl,
          validator=validator,
          tester=tester,
          epochs=args.epochs,
          lr=args.lr,
          weight_decay=args.weight_decay)

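    # Only the first process reloads the saved checkpoint and records the final results.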
    if proc_id == 0:
        predictor.use_pretrained_model(args.save_model_fname, device=device)
        bookkeep(predictor, validator, tester, args, dataset.INPUT)