def train():
    rank_id = 0
    if args.run_distribute:
        context.set_auto_parallel_context(device_num=args.device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          parameter_broadcast=True)
        init()
        rank_id = get_rank()

    # dataset/network/criterion/optimizer
    ds = train_dataset_creator(args.device_id, args.device_num)
    step_size = ds.get_dataset_size()
    print('Create dataset done!')

    config.INFERENCE = False
    net = ETSNet(config)
    net = net.set_train()
    param_dict = load_checkpoint(args.pre_trained)
    load_param_into_net(net, param_dict)
    print('Load Pretrained parameters done!')

    criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE)
    lrs = lr_generator(start_lr=1e-3, lr_scale=0.1,
                       total_iters=config.TRAIN_TOTAL_ITER)
    opt = nn.SGD(params=net.trainable_params(), learning_rate=lrs,
                 momentum=0.99, weight_decay=5e-4)

    # wrap the network with the loss and the single-step training cell
    net = WithLossCell(net, criterion)
    if args.run_distribute:
        net = TrainOneStepCell(net, opt, reduce_flag=True,
                               mean=True, degree=args.device_num)
    else:
        net = TrainOneStepCell(net, opt)

    time_cb = TimeMonitor(data_size=step_size)
    loss_cb = LossCallBack(per_print_times=10)
    # set and apply checkpoint parameters (config.TRAIN_MODEL_SAVE_PATH)
    ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=2)
    ckpoint_cb = ModelCheckpoint(prefix="ETSNet", config=ckpoint_cf,
                                 directory="./ckpt_{}".format(rank_id))

    model = Model(net)
    model.train(config.TRAIN_REPEAT_NUM, ds, dataset_sink_mode=True,
                callbacks=[time_cb, loss_cb, ckpoint_cb])
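The lr_generator helper is not defined in this snippet. Below is a minimal sketch of what it might look like, assuming a plain step-decay schedule; the decay boundaries (thirds of training) and the NumPy-based construction are assumptions, not the repository's actual implementation.

import numpy as np
from mindspore import Tensor

def lr_generator(start_lr, lr_scale, total_iters):
    """Hypothetical step-decay schedule: the base rate is multiplied by
    lr_scale at 1/3 and again at 2/3 of training. The real lr_generator
    may use different decay points."""
    lrs = []
    for i in range(total_iters):
        if i < total_iters // 3:
            lrs.append(start_lr)                        # first third: base LR
        elif i < 2 * total_iters // 3:
            lrs.append(start_lr * lr_scale)             # middle third: one decay
        else:
            lrs.append(start_lr * lr_scale * lr_scale)  # last third: two decays
    # MindSpore optimizers accept a 1-D Tensor as a per-step dynamic LR
    return Tensor(np.array(lrs, dtype=np.float32))

Because nn.SGD indexes a per-step Tensor by the global step counter, total_iters must match config.TRAIN_TOTAL_ITER used in the training loop above.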
lr = Tensor(dynamic_lr(training_cfg, dataset_size), mstype.float32)
opt = Momentum(params=net.trainable_params(), learning_rate=lr,
               momentum=config.momentum, weight_decay=config.weight_decay,
               loss_scale=config.loss_scale)

net_with_loss = WithLossCell(net, loss)
if args_opt.run_distribute:
    net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale,
                           reduce_flag=True, mean=True, degree=device_num)
else:
    net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)

time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossCallBack(rank_id=rank)
cb = [time_cb, loss_cb]
if config.save_checkpoint:
    ckptconfig = CheckpointConfig(
        save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size,
        keep_checkpoint_max=config.keep_checkpoint_max)
    save_checkpoint_path = os.path.join(config.save_checkpoint_path,
                                        "ckpt_" + str(rank) + "/")
    ckpoint_cb = ModelCheckpoint(prefix='ctpn', directory=save_checkpoint_path,
                                 config=ckptconfig)
    cb += [ckpoint_cb]

model = Model(net)
model.train(training_cfg.total_epoch, dataset, callbacks=cb)
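LossCallBack is likewise repo-specific; both call sites above (per_print_times=10 in the first snippet, rank_id=rank here) can be served by a standard MindSpore Callback. A minimal sketch, assuming it simply prints the step loss; the message format and the default argument values are assumptions.

from mindspore.train.callback import Callback

class LossCallBack(Callback):
    """Hypothetical sketch: log the training loss per step for one rank.
    The repo's LossCallBack may additionally smooth the loss or write to
    a file; only the keyword names are taken from the call sites above."""

    def __init__(self, rank_id=0, per_print_times=1):
        super(LossCallBack, self).__init__()
        self.rank_id = rank_id
        self.per_print_times = per_print_times

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self.per_print_times == 0:
            # net_outputs holds the loss returned by TrainOneStepCell
            print("rank: {}, epoch: {}, step: {}, loss: {}".format(
                self.rank_id, cb_params.cur_epoch_num,
                cb_params.cur_step_num, cb_params.net_outputs))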
if config.pretrain_epoch_size == 0:
    # keep only backbone and rcnn_mask weights from the pretrained checkpoint
    for item in list(param_dict.keys()):
        if not (item.startswith('backbone') or item.startswith('rcnn_mask')):
            param_dict.pop(item)
load_param_into_net(net, param_dict)

loss = LossNet()
lr = Tensor(dynamic_lr(config, rank_size=device_num,
                       start_steps=config.pretrain_epoch_size * dataset_size),
            mstype.float32)
opt = SGD(params=net.trainable_params(), learning_rate=lr,
          momentum=config.momentum, weight_decay=config.weight_decay,
          loss_scale=config.loss_scale)

net_with_loss = WithLossCell(net, loss)
if args_opt.run_distribute:
    net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale,
                           reduce_flag=True, mean=True, degree=device_num)
else:
    net = TrainOneStepCell(net_with_loss, opt, sens=config.loss_scale)

time_cb = TimeMonitor(data_size=dataset_size)
loss_cb = LossCallBack()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
    ckptconfig = CheckpointConfig(
        save_checkpoint_steps=config.save_checkpoint_epochs * dataset_size,
        keep_checkpoint_max=config.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix='mask_rcnn',
                                 directory=config.save_checkpoint_path,
                                 config=ckptconfig)
    cb += [ckpoint_cb]

model = Model(net)
model.train(config.epoch_size, dataset, callbacks=cb)
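Both the CTPN and Mask R-CNN snippets build their schedules with dynamic_lr. A sketch matching the call signature used here, assuming a warmup-then-cosine-decay shape; every config field it reads (base_lr, warmup_step, warmup_ratio, total_steps) and the linear scaling by rank_size are assumptions about what the repo's config carries, not the actual implementation.

import math

def dynamic_lr(config, rank_size=1, start_steps=0):
    """Hypothetical warmup + cosine-decay schedule; the real dynamic_lr
    in the repo may use a different shape and different config fields."""
    # linear LR scaling with device count is an assumption, not confirmed
    base_lr = config.base_lr * rank_size
    total_steps = config.total_steps
    lr = []
    for i in range(total_steps):
        if i < config.warmup_step:
            # linear warmup from base_lr * warmup_ratio up to base_lr
            lr.append(base_lr * (config.warmup_ratio +
                      (1 - config.warmup_ratio) * i / config.warmup_step))
        else:
            # cosine decay from base_lr down to zero after warmup
            progress = (i - config.warmup_step) / (total_steps - config.warmup_step)
            lr.append(base_lr * 0.5 * (1 + math.cos(math.pi * progress)))
    # skip steps already consumed by pretraining so a resumed run picks up
    # the schedule at the right point (start_steps comes from the call site)
    return lr[start_steps:]

The returned list is wrapped in Tensor(..., mstype.float32) at the call site, which is what turns it into a per-step dynamic learning rate for the optimizer.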