def test_optimizer_serialization():
    """Verify that an SGD optimizer keeps working after a state_dict round trip.

    The test takes one step with an optimizer built with lr=0.01 and
    momentum=0.9, serializes its state into an in-memory buffer, and loads
    that state into a second optimizer that was constructed with different
    hyper-parameters (lr=0.02, momentum=0.8). After one more step with the
    second optimizer, every parameter is compared against a NumPy reference
    computed with the *first* optimizer's hyper-parameters, i.e. the test
    expects the loaded state to take precedence over the constructor values.
    """
    # Hyper-parameters of the optimizer whose state gets serialized; the
    # NumPy reference below uses the same values for both steps.
    lr, momentum = 0.01, 0.9

    data, data_shape, label, label_shape = get_input()
    net = MLP()
    optimizer = SGD(net.parameters(), lr=lr, momentum=momentum)

    # Reference per-parameter velocity buffers, starting from zero.
    velocity = TensorDict()
    for p in net.parameters():
        velocity[p] = np.zeros(p.shape).astype(np.float32)

    def forward_backward(active_opt):
        # Run the network on the current inputs and backpropagate the loss
        # through the given optimizer.
        prediction = net(data)
        loss = F.square_loss(prediction, label.reshape(-1, 1))
        active_opt.zero_grad()
        active_opt.backward(loss)

    def advance_reference(p):
        # In-place reference update: v <- momentum * v - lr * grad.
        buf = velocity[p]
        buf *= momentum
        buf -= p.grad.numpy() * lr
        return buf

    # First step with the original optimizer; mirror it in the reference.
    forward_backward(optimizer)
    optimizer.step()
    for p in net.parameters():
        advance_reference(p)

    # Round-trip the optimizer state through an in-memory file.
    with BytesIO() as buffer:
        save(optimizer.state_dict(), buffer)
        buffer.seek(0)
        restored_state = load(buffer)

    # Deliberately different constructor hyper-parameters; the reference
    # check below only holds if the loaded state replaces them.
    restored_opt = SGD(net.parameters(), lr=0.02, momentum=0.8)
    restored_opt.load_state_dict(restored_state)

    # Fresh random inputs for the second step.
    data.set_value(np.random.random(data_shape).astype(np.float32))
    label.set_value(np.random.randint(0, 10, label_shape))
    forward_backward(restored_opt)

    # Snapshot parameter values before the second step.
    before_step = TensorDict()
    for p in net.parameters():
        before_step[p] = np.copy(p.numpy())

    restored_opt.step()
    for p in net.parameters():
        expected = before_step[p] + advance_reference(p)
        assertTensorClose(p.numpy(), expected)
def worker(args):
    """Build the network described by ``args.file`` and train it.

    Steps performed, in order:

    1. Import ``args.file`` and instantiate its ``Net`` with its ``Cfg``.
    2. Create an SGD optimizer with two parameter groups: parameters whose
       name contains ``"backbone"`` use 0.1x the configured learning rate,
       all others use the configured learning rate.
    3. Attach the parameters to a ``GradManager`` (with a SUM all-reduce
       callback when more than one process participates).
    4. Optionally resume model/optimizer state from ``args.resume``.
    5. Train epoch by epoch; rank 0 dumps a checkpoint after every epoch.

    Args:
        args: object exposing ``file``, ``resume`` and ``dataset_dir``
            attributes (presumably an ``argparse.Namespace`` -- confirm
            against the caller).
    """
    current_network = import_from_file(args.file)
    model = current_network.Net(current_network.Cfg())
    model.train()

    cfg = model.cfg
    world_size = dist.get_world_size()
    is_master = dist.get_rank() == 0

    if is_master:
        logger.info(get_config_info(cfg))
        logger.info(repr(model))

    # Split parameters so the backbone can train with a reduced learning rate.
    backbone_params, head_params = [], []
    for name, param in model.named_parameters():
        bucket = backbone_params if "backbone" in name else head_params
        bucket.append(param)

    param_groups = [
        {"params": backbone_params, "lr": cfg.learning_rate * 0.1},
        {"params": head_params},
    ]
    opt = SGD(
        param_groups,
        lr=cfg.learning_rate,
        momentum=cfg.momentum,
        # Weight decay is scaled by the number of processes.
        weight_decay=cfg.weight_decay * world_size,
    )

    gm = GradManager()
    attach_kwargs = {}
    if world_size > 1:
        # Sum gradients across all processes.
        attach_kwargs["callbacks"] = [dist.make_allreduce_cb("SUM", dist.WORLD)]
    gm.attach(model.parameters(), **attach_kwargs)

    start_epoch = 0
    if args.resume is not None:
        checkpoint = mge.load(args.resume)
        # The checkpoint stores the last finished epoch; continue after it.
        start_epoch = checkpoint["epoch"] + 1
        model.load_state_dict(checkpoint["state_dict"])
        opt.load_state_dict(checkpoint["opt"])
        if is_master:
            logger.info("load success: epoch %d", start_epoch)

    if world_size > 1:
        dist.bcast_list_(model.parameters(), dist.WORLD)  # sync parameters

    if is_master:
        logger.info("Prepare dataset")
    train_loader = iter(
        build_dataloader(cfg.batch_size, args.dataset_dir, cfg)
    )

    for epoch in range(start_epoch, cfg.max_epoch):
        train_one_epoch(model, train_loader, opt, gm, epoch)

        # Only rank 0 writes checkpoints.
        if not is_master:
            continue

        network_name = os.path.basename(args.file).split(".")[0]
        save_path = "log-of-{}/epoch_{}.pkl".format(network_name, epoch)
        checkpoint = {
            "epoch": epoch,
            "state_dict": model.state_dict(),
            "opt": opt.state_dict(),
        }
        mge.save(checkpoint, save_path)
        logger.info("dump weights to %s", save_path)