# Framework imports (MindSpore 1.x); parse_args, BuildTrainNetwork and the
# src.* modules belong to the accompanying repository (module paths assumed).
from mindspore import context, nn
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.context import ParallelMode
from mindspore.train.callback import CheckpointConfig, LossMonitor, ModelCheckpoint, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.data import data_generator
from src.loss import loss
from src.nets import net_factory
from src.utils import learning_rates


def train():
    args = parse_args()

    # init multicards training
    if args.is_distributed:
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()

        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True,
                                          device_num=args.group_size)

    # dataset
    dataset = data_generator.SegDataset(image_mean=args.image_mean,
                                        image_std=args.image_std,
                                        data_file=args.data_file,
                                        batch_size=args.batch_size,
                                        crop_size=args.crop_size,
                                        max_scale=args.max_scale,
                                        min_scale=args.min_scale,
                                        ignore_label=args.ignore_label,
                                        num_classes=args.num_classes,
                                        num_readers=2,
                                        num_parallel_calls=4,
                                        shard_id=args.rank,
                                        shard_num=args.group_size)
    dataset = dataset.get_dataset(repeat=1)

    # network
    if args.model == 'deeplab_v3_s16':
        network = net_factory.nets_map[args.model]('train', args.num_classes, 16, args.freeze_bn)
    elif args.model == 'deeplab_v3_s8':
        network = net_factory.nets_map[args.model]('train', args.num_classes, 8, args.freeze_bn)
    else:
        raise NotImplementedError('model [{:s}] not recognized'.format(args.model))

    # loss
    loss_ = loss.SoftmaxCrossEntropyLoss(args.num_classes, args.ignore_label)
    loss_.add_flags_recursive(fp32=True)
    train_net = BuildTrainNetwork(network, loss_)

    # load pretrained model
    if args.ckpt_pre_trained:
        param_dict = load_checkpoint(args.ckpt_pre_trained)
        load_param_into_net(train_net, param_dict)

    # optimizer
    iters_per_epoch = dataset.get_dataset_size()
    total_train_steps = iters_per_epoch * args.train_epochs
    if args.lr_type == 'cos':
        lr_iter = learning_rates.cosine_lr(args.base_lr, total_train_steps, total_train_steps)
    elif args.lr_type == 'poly':
        lr_iter = learning_rates.poly_lr(args.base_lr, total_train_steps, total_train_steps,
                                         end_lr=0.0, power=0.9)
    elif args.lr_type == 'exp':
        lr_iter = learning_rates.exponential_lr(args.base_lr, args.lr_decay_step, args.lr_decay_rate,
                                                total_train_steps, staircase=True)
    else:
        raise ValueError('unknown learning rate type')
    opt = nn.Momentum(params=train_net.trainable_params(), learning_rate=lr_iter, momentum=0.9,
                      weight_decay=0.0001, loss_scale=args.loss_scale)

    # loss scale
    manager_loss_scale = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
    model = Model(train_net, optimizer=opt, amp_level="O3", loss_scale_manager=manager_loss_scale)

    # callback for saving ckpts
    time_cb = TimeMonitor(data_size=iters_per_epoch)
    loss_cb = LossMonitor()
    cbs = [time_cb, loss_cb]

    if args.rank == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=args.save_steps,
                                     keep_checkpoint_max=args.keep_checkpoint_max)
        ckpoint_cb = ModelCheckpoint(prefix=args.model, directory=args.train_dir, config=config_ck)
        cbs.append(ckpoint_cb)

    model.train(args.train_epochs, dataset, callbacks=cbs)
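
# BuildTrainNetwork is referenced above but not shown. A minimal sketch of what
# it presumably is -- a cell that chains the backbone and the criterion so that
# Model() receives a single scalar loss; the class body below is an assumption,
# not necessarily the repository's exact implementation.
import mindspore.nn as nn


class BuildTrainNetwork(nn.Cell):
    """Wrap a backbone network and a criterion into one cell returning the loss."""

    def __init__(self, network, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion

    def construct(self, input_data, label):
        output = self.network(input_data)          # forward pass of the backbone
        net_loss = self.criterion(output, label)   # scalar loss for the optimizer
        return net_loss
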
# Framework imports (MindSpore 1.x); parse_args, FCN8s_VOC2012_cfg, FCN8s,
# CosineAnnealingLR and the src.* modules belong to the accompanying
# repository (module paths assumed).
import os

from mindspore import Tensor, context, nn
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.context import ParallelMode
from mindspore.train.callback import CheckpointConfig, LossMonitor, ModelCheckpoint, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.data import data_generator
from src.loss import loss
from src.nets.FCN8s import FCN8s
from src.utils.lr_scheduler import CosineAnnealingLR


def train():
    args = parse_args()
    cfg = FCN8s_VOC2012_cfg
    device_num = int(os.environ.get("DEVICE_NUM", 1))
    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
                        device_target="Ascend", device_id=args.device_id)

    # init multicards training
    args.rank = 0
    args.group_size = 1
    if device_num > 1:
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True,
                                          device_num=device_num)
        init()
        args.rank = get_rank()
        args.group_size = get_group_size()

    # dataset
    dataset = data_generator.SegDataset(image_mean=cfg.image_mean,
                                        image_std=cfg.image_std,
                                        data_file=cfg.data_file,
                                        batch_size=cfg.batch_size,
                                        crop_size=cfg.crop_size,
                                        max_scale=cfg.max_scale,
                                        min_scale=cfg.min_scale,
                                        ignore_label=cfg.ignore_label,
                                        num_classes=cfg.num_classes,
                                        num_readers=2,
                                        num_parallel_calls=4,
                                        shard_id=args.rank,
                                        shard_num=args.group_size)
    dataset = dataset.get_dataset(repeat=1)

    net = FCN8s(n_class=cfg.num_classes)
    loss_ = loss.SoftmaxCrossEntropyLoss(cfg.num_classes, cfg.ignore_label)

    # load pretrained vgg16 parameters to init FCN8s
    if cfg.ckpt_vgg16:
        param_vgg = load_checkpoint(cfg.ckpt_vgg16)
        param_dict = {}
        for layer_id in range(1, 6):
            sub_layer_num = 2 if layer_id < 3 else 3
            for sub_layer_id in range(sub_layer_num):
                # conv param
                y_weight = 'conv{}.{}.weight'.format(layer_id, 3 * sub_layer_id)
                x_weight = 'vgg16_feature_extractor.conv{}_{}.0.weight'.format(layer_id, sub_layer_id + 1)
                param_dict[y_weight] = param_vgg[x_weight]
                # BatchNorm param
                y_gamma = 'conv{}.{}.gamma'.format(layer_id, 3 * sub_layer_id + 1)
                y_beta = 'conv{}.{}.beta'.format(layer_id, 3 * sub_layer_id + 1)
                x_gamma = 'vgg16_feature_extractor.conv{}_{}.1.gamma'.format(layer_id, sub_layer_id + 1)
                x_beta = 'vgg16_feature_extractor.conv{}_{}.1.beta'.format(layer_id, sub_layer_id + 1)
                param_dict[y_gamma] = param_vgg[x_gamma]
                param_dict[y_beta] = param_vgg[x_beta]
        load_param_into_net(net, param_dict)
    # load pretrained FCN8s
    elif cfg.ckpt_pre_trained:
        param_dict = load_checkpoint(cfg.ckpt_pre_trained)
        load_param_into_net(net, param_dict)

    # optimizer
    iters_per_epoch = dataset.get_dataset_size()
    lr_scheduler = CosineAnnealingLR(cfg.base_lr, cfg.train_epochs, iters_per_epoch, cfg.train_epochs,
                                     warmup_epochs=0, eta_min=0)
    lr = Tensor(lr_scheduler.get_lr())

    # loss scale
    manager_loss_scale = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
    optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9,
                            weight_decay=0.0001, loss_scale=cfg.loss_scale)
    model = Model(net, loss_fn=loss_, loss_scale_manager=manager_loss_scale,
                  optimizer=optimizer, amp_level="O3")

    # callback for saving ckpts
    time_cb = TimeMonitor(data_size=iters_per_epoch)
    loss_cb = LossMonitor()
    cbs = [time_cb, loss_cb]

    if args.rank == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_steps,
                                     keep_checkpoint_max=cfg.keep_checkpoint_max)
        ckpoint_cb = ModelCheckpoint(prefix=cfg.model, directory=cfg.ckpt_dir, config=config_ck)
        cbs.append(ckpoint_cb)

    model.train(cfg.train_epochs, dataset, callbacks=cbs)
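
# CosineAnnealingLR above is the project's own scheduler, consumed through
# get_lr() as a per-step learning-rate list. A stand-in sketch under the same
# constructor signature as the call above (cosine decay from base_lr toward
# eta_min over T_max epochs, with optional linear warmup); the implementation
# details below are an assumption, not the repository's exact code.
import math


class CosineAnnealingLR:
    """Generate one learning-rate value per training step."""

    def __init__(self, lr, T_max, steps_per_epoch, max_epoch, warmup_epochs=0, eta_min=0):
        self.base_lr = lr
        self.T_max = T_max                    # epochs over which the cosine decays
        self.steps_per_epoch = steps_per_epoch
        self.total_steps = int(max_epoch * steps_per_epoch)
        self.warmup_steps = int(warmup_epochs * steps_per_epoch)
        self.eta_min = eta_min

    def get_lr(self):
        lr_each_step = []
        for step in range(self.total_steps):
            if self.warmup_steps and step < self.warmup_steps:
                # linear warmup from 0 up to base_lr
                cur_lr = self.base_lr * (step + 1) / self.warmup_steps
            else:
                cur_epoch = step // self.steps_per_epoch
                cur_lr = self.eta_min + 0.5 * (self.base_lr - self.eta_min) * \
                    (1 + math.cos(math.pi * cur_epoch / self.T_max))
            lr_each_step.append(cur_lr)
        return lr_each_step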