argv: 1: 使用哪个显卡 2: 为0则重新开始训练 否则读取之前的模型 3: 学习率 Adam默认是1e-3 4: batchSize ''' if (len(sys.argv) != 5): print('1: 使用哪个显卡\n' '2: 为0则重新开始训练 否则读取之前的模型\n' '3: 学习率 Adam默认是1e-3\n' '4: batchSize') exit(0) batchSize = int(sys.argv[4]) # 一次读取?张图片进行训练 dReader = bmpLoader.datasetReader(colorFlag=True, batchSize=batchSize) torch.cuda.set_device(int(sys.argv[1])) # 设置使用哪个显卡 if (sys.argv[2] == '0'): # 设置是重新开始 还是继续训练 encNet = EncodeNet(256).cuda().train() decNet = DecodeNet(256).cuda().train() print('create new model') else: encNet = torch.load('../models/encNet_' + sys.argv[0] + '.pkl', map_location='cuda:' + sys.argv[1]).cuda().train() decNet = torch.load('../models/decNet_' + sys.argv[0] + '.pkl', map_location='cuda:' + sys.argv[1]).cuda().train() print('read ../models/' + sys.argv[0] + '.pkl') print(encNet) print(decNet)
argv: 1: 使用哪个显卡 2: 为0则重新开始训练 否则读取之前的模型 3: 学习率 Adam默认是1e-3 4: batchSize ''' if (len(sys.argv) != 5): print('1: 使用哪个显卡\n' '2: 为0则重新开始训练 否则读取之前的模型\n' '3: 学习率 Adam默认是1e-3\n' '4: batchSize') exit(0) batchSize = int(sys.argv[4]) # 一次读取?张图片进行训练 dReader = bmpLoader.datasetReader(colorFlag=True, batchSize=batchSize) torch.cuda.set_device(int(sys.argv[1])) # 设置使用哪个显卡 if (sys.argv[2] == '0'): # 设置是重新开始 还是继续训练 encNet = EncodeNet().cuda().train() decNet = DecodeNet().cuda().train() print('create new model') else: encNet = torch.load('../models/encNet_' + sys.argv[0] + '.pkl', map_location='cuda:' + sys.argv[1]).cuda().train() decNet = torch.load('../models/decNet_' + sys.argv[0] + '.pkl', map_location='cuda:' + sys.argv[1]).cuda().train() print('read ../models/' + sys.argv[0] + '.pkl') print(encNet) print(decNet)