        constant.args.continue_from)
    start_epoch = epoch  # index starts from zero
    verbose = constant.args.verbose

    if loaded_args is not None:
        # Unwrap nn.DataParallel
        if loaded_args.parallel:
            logging.info("unwrap from DataParallel")
            model = model.module

    # Parallelize the batch
    if args.parallel:
        model = nn.DataParallel(model, device_ids=args.device_ids)
else:
    # No checkpoint given: build a new model and optimizer from the arguments
    if constant.args.model == "TRFS":
        model = init_transformer_model(constant.args, label2id, id2label)
        opt = init_optimizer(constant.args, model, "noam")
    else:
        logging.info("The model is not supported, check args --h")

loss_type = args.loss

if constant.USE_CUDA:
    model = model.cuda(0)

logging.info(model)

num_epochs = constant.args.epochs

trainer = Trainer()
trainer.train(model, train_loader, train_sampler, valid_loader_list, opt,
              loss_type, start_epoch, num_epochs, label2id, id2label, metrics)
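# ---------------------------------------------------------------------------
# For reference: a minimal sketch of what the "noam" option passed to
# init_optimizer above is assumed to do, i.e. wrap Adam with the
# warmup-then-decay learning-rate schedule from "Attention Is All You Need".
# The class name, constructor arguments, and usage below are illustrative
# assumptions; the repo's actual wrapper may differ.
# ---------------------------------------------------------------------------
import torch


class NoamOpt:
    """Adam wrapped with the Noam learning-rate schedule (assumed interface)."""

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self.model_size = model_size   # d_model of the transformer
        self.factor = factor           # overall LR scaling factor
        self.warmup = warmup           # number of warmup steps
        self._step = 0

    def rate(self, step=None):
        # lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
        step = self._step if step is None else step
        return self.factor * (self.model_size ** -0.5) * \
            min(step ** -0.5, step * (self.warmup ** -1.5))

    def step(self):
        self._step += 1
        lr = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = lr
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()


# Hypothetical usage, mirroring init_optimizer(constant.args, model, "noam"):
# opt = NoamOpt(model_size=512, factor=1.0, warmup=4000,
#               optimizer=torch.optim.Adam(model.parameters(), lr=0,
#                                          betas=(0.9, 0.98), eps=1e-9))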
        num_workers=args.num_workers)
    valid_loader_list.append(valid_loader)

start_epoch = 0
metrics = None
loaded_args = None

if args.continue_from != "":
    # Resume training from a saved joint checkpoint
    logging.info("Continue from checkpoint: " + args.continue_from)
    model, vocab, opt, epoch, metrics, loaded_args = load_joint_model(
        args.continue_from)
    start_epoch = epoch  # index starts from zero
    verbose = args.verbose
else:
    # No checkpoint given: build a new model from the arguments
    if args.model == "TRFS":
        model = init_transformer_model(args, vocab,
                                       is_factorized=args.is_factorized,
                                       r=args.r)
    else:
        logging.info("The model is not supported, check args --h")

loss_type = args.loss

if USE_CUDA:
    model = model.cuda()

logging.info(model)

num_epochs = args.epochs

print("Parameters: {} (trainable), {} (non-trainable)".format(
    compute_num_params(model)[0], compute_num_params(model)[1]))
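# ---------------------------------------------------------------------------
# For reference: a minimal sketch of what compute_num_params above is assumed
# to return, i.e. a (trainable, non-trainable) parameter-count pair. The name
# compute_num_params_sketch and its body are illustrative; the repo's actual
# helper may differ in detail.
# ---------------------------------------------------------------------------
def compute_num_params_sketch(model):
    """Count trainable and frozen parameters of a torch.nn.Module."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable, frozen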